线段树维护区间最大子段和
区间最大子段和,指的是区间中一段连续的数的和的最大值,形式化的写作 max l , r ∑ i = l r a i \max_{l,r}\sum_{i=l}^ra_i max l , r ∑ i = l r a i 。
单独求一个区间的最大子段和可以使用动态规划算法在 O ( n ) O(n) O ( n ) 时间内求出。但是,如果涉及到区间更改,或求一个子区间的区间最大子段和就无法如此做了。
事实上,使用线段树完全可以用来维护最大子段和。
定义
我们给每个节点维护四个值:区间和 s u m sum s u m ,从左端点开始的最大子段和 l s ls l s ,在右端点结束的最大子段和 r s rs rs ,总共的最大子段和 m s ms m s 。
1 2 3 struct Node { int sum,ls,rs,ms; };
pushup
我们考虑如何将这四个值由下面的两个区间 l c h , r c h lch,rch l c h , rc h 更新得到(pushup 操作)。
区间和 s u m sum s u m
s u m = s u m l c h + s u m r c h sum=sum_{lch}+sum_{rch} s u m = s u m l c h + s u m rc h 。
从左端点开始的最大子段和 l s ls l s
左端点开始的最大子段和可能是只在左儿子中,或者是左儿子的全部加上右儿子的 l s ls l s 。
l s = max ( l s l c h , s u m l c h + l s r c h ) ls=\max(ls_{lch},sum_{lch}+ls_{rch}) l s = max ( l s l c h , s u m l c h + l s rc h ) 。
在右端点结束的最大子段和 r s rs rs
同上。
r s = max ( r s r c h , s u m r c h + r s l c h ) rs=\max(rs{rch},sum_{rch}+rs_{lch}) rs = max ( rs rc h , s u m rc h + r s l c h ) 。
总共的最大子段和 m s ms m s
m s ms m s 要么是左儿子和右儿子内部的 m s ms m s ,要么是左右儿子各取一点拼起来。
m s = max ( m s l c h , m s r c h , r s l c h + l s r c h ) ms=\max(ms_{lch},ms_{rch},rs_{lch}+ls_{rch}) m s = max ( m s l c h , m s rc h , r s l c h + l s rc h ) 。
写成代码形式:
1 2 3 4 5 6 7 void pushup (int o) { int lch=o<<1 ,rch=o<<1 |1 ; t[o].s=t[lch].s+t[rch].s; t[o].ls=max (t[lch].ls,t[lch].s+t[rch].ls); t[o].rs=max (t[rch].rs,t[rch].s+t[lch].rs); t[o].ms=max ({t[lch].ms,t[rch].ms,t[lch].rs+t[rch].ls}); }
区间查询
复杂度 O ( log n ) O(\log n) O ( log n ) 。与普通的区间查询不同,在这里我们区间查询需要返回四个值 s , l s , r s , m s s,ls,rs,ms s , l s , rs , m s ,来在递归的时候进行汇总左右两边的结果,汇总过程和 pushup 类似。
单点修改
复杂度 O ( log n ) O(\log n) O ( log n ) 。非常直观的操作,不断递归,直到一个节点,最后不断 pushup 回去。
1 2 3 4 5 6 7 8 9 10 11 void modify (int o,int l,int r,int qi,int qk) { if (l==r){ t[o].s=t[o].ls=t[o].rs=t[o].ms=qk; return ; } int lch = o<<1 , rch=o<<1 |1 ; int mid=(l+r)>>1 ; if (qi<=mid)modify (lch,l,mid,qi,qk); else modify (rch,mid+1 ,r,qi,qk); pushup (o); }
例题
[CF1692H] Gambling
题意
给定一个序列 a 1 , a 2 , … , a n a_1,a_2,\ldots,a_n a 1 , a 2 , … , a n ,求出 a , l , r a,l,r a , l , r ,使得在 [ l , r ] [l,r] [ l , r ] 范围内 a a a 出现次数减去其他数个数的数量最大。
解法
考虑对于一个给定的 a a a 求出 l , r l,r l , r 。这其实就是最大子段和,我们将所有位置上等于 a a a 的数置为 1 1 1 ,其他置为 − 1 -1 − 1 即可。
用一个 map 储存每个元素的下标,这样每个元素的位置只需要修改两次,总复杂度 O ( n log n ) O(n\log n) O ( n log n ) 。
求出 l , r l,r l , r 的问题也可以使用线段树维护,但是有一点麻烦。可以先求出正确的 a a a ,然后跑一次 dp 求出合法的 l , r l,r l , r 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 #include <bits/stdc++.h> using namespace std;using ll=long long ;using ull=unsigned long long ;using pii=pair<int ,int >;#define all(x) x.begin(),x.end() #define mem0(x) memset(x,0,sizeof(x)) #define YES puts("YES" ) #define NO puts("NO" ) #define Yes puts("Yes" ) #define No puts("No" ) #define errorf(...) fprintf(stderr, __VA_ARGS__) #define endl '\n' int read () { int f=1 ,x=0 ;char c=getchar (); while (!isdigit (c)){if (c=='-' )f=-1 ;c=getchar ();} while (isdigit (c)){x=x*10 +c-'0' ;c=getchar ();} return x*f; }const int N = (2e5 +5 )*4 ;struct Node { int s,ls,rs,ms; };struct SegementTree { Node t[N]; void pushup (int o) { int lch=o<<1 ,rch=o<<1 |1 ; t[o].s=t[lch].s+t[rch].s; t[o].ls=max (t[lch].ls,t[lch].s+t[rch].ls); t[o].rs=max (t[rch].rs,t[rch].s+t[lch].rs); t[o].ms=max ({t[lch].ms,t[rch].ms,t[lch].rs+t[rch].ls}); } void build (int o,int l,int r) { if (l==r){ t[o].s=t[o].ls=t[o].rs=t[o].ms=-1 ; return ; } int lch = o<<1 , rch=o<<1 |1 ; int mid=(l+r)>>1 ; build (lch,l,mid); build (rch,mid+1 ,r); pushup (o); } void modify (int o,int l,int r,int qi,int qk) { if (l==r){ t[o].s=t[o].ls=t[o].rs=t[o].ms=qk; return ; } int lch = o<<1 , rch=o<<1 |1 ; int mid=(l+r)>>1 ; if (qi<=mid)modify (lch,l,mid,qi,qk); else modify (rch,mid+1 ,r,qi,qk); pushup (o); } }st;int a[N];int dp[N],l[N];void solve (int kase) { int n=read (); map<int , vector<int >> mp; st.build (1 ,1 ,n); for (int i=1 ;i<=n;i++){ a[i]=read (); mp[a[i]].push_back (i); } int ans = 0 ; int x; for (auto &[val,v]:mp){ for (int i:v)st.modify (1 ,1 ,n,i,1 ); if (st.t[1 ].ms>ans){ ans=st.t[1 ].ms; x=val; } for (int i:v)st.modify (1 ,1 ,n,i,-1 ); } for (int i=1 ;i<=n;i++){ if (a[i]==x)a[i]=1 ; else a[i]=-1 ; } printf ("%d " ,x); dp[1 ]=a[1 ]; l[1 ]=1 ; for (int i=2 ;i<=n;i++){ dp[i]=max (dp[i-1 ]+a[i],a[i]); if (dp[i]==dp[i-1 ]+a[i])l[i]=l[i-1 ]; else l[i]=i; } int ansr = max_element (dp+1 ,dp+1 +n)-dp; printf ("%d %d\n" ,l[ansr],ansr); }int main () { int T=read (),TT=1 ; while (T--){ solve (TT++); } return 0 ; }