线段树维护区间最大子段和

区间最大子段和,指的是区间中一段连续的数的和的最大值,形式化的写作 maxl,ri=lrai\max_{l,r}\sum_{i=l}^ra_i

单独求一个区间的最大子段和可以使用动态规划算法在 O(n)O(n) 时间内求出。但是,如果涉及到区间更改,或求一个子区间的区间最大子段和就无法如此做了。

事实上,使用线段树完全可以用来维护最大子段和。

定义

我们给每个节点维护四个值:区间和 sumsum,从左端点开始的最大子段和 lsls,在右端点结束的最大子段和 rsrs,总共的最大子段和 msms

1
2
3
struct Node{
int sum,ls,rs,ms;
};

pushup

我们考虑如何将这四个值由下面的两个区间 lch,rchlch,rch 更新得到(pushup 操作)。

  1. 区间和 sumsum
    sum=sumlch+sumrchsum=sum_{lch}+sum_{rch}
  2. 从左端点开始的最大子段和 lsls
    左端点开始的最大子段和可能是只在左儿子中,或者是左儿子的全部加上右儿子的 lsls
    ls=max(lslch,sumlch+lsrch)ls=\max(ls_{lch},sum_{lch}+ls_{rch})
  3. 在右端点结束的最大子段和 rsrs
    同上。
    rs=max(rsrch,sumrch+rslch)rs=\max(rs{rch},sum_{rch}+rs_{lch})
  4. 总共的最大子段和 msms
    msms 要么是左儿子和右儿子内部的 msms,要么是左右儿子各取一点拼起来。
    ms=max(mslch,msrch,rslch+lsrch)ms=\max(ms_{lch},ms_{rch},rs_{lch}+ls_{rch})

写成代码形式:

1
2
3
4
5
6
7
void pushup(int o){
int lch=o<<1,rch=o<<1|1;
t[o].s=t[lch].s+t[rch].s;
t[o].ls=max(t[lch].ls,t[lch].s+t[rch].ls);
t[o].rs=max(t[rch].rs,t[rch].s+t[lch].rs);
t[o].ms=max({t[lch].ms,t[rch].ms,t[lch].rs+t[rch].ls});
}

区间查询

复杂度 O(logn)O(\log n)。与普通的区间查询不同,在这里我们区间查询需要返回四个值 s,ls,rs,mss,ls,rs,ms,来在递归的时候进行汇总左右两边的结果,汇总过程和 pushup 类似。

单点修改

复杂度 O(logn)O(\log n)。非常直观的操作,不断递归,直到一个节点,最后不断 pushup 回去。

1
2
3
4
5
6
7
8
9
10
11
void modify(int o,int l,int r,int qi,int qk){
if(l==r){
t[o].s=t[o].ls=t[o].rs=t[o].ms=qk;
return;
}
int lch = o<<1, rch=o<<1|1;
int mid=(l+r)>>1;
if(qi<=mid)modify(lch,l,mid,qi,qk);
else modify(rch,mid+1,r,qi,qk);
pushup(o);
}

例题

[CF1692H] Gambling

题意

给定一个序列 a1,a2,,ana_1,a_2,\ldots,a_n,求出 a,l,ra,l,r,使得在 [l,r][l,r] 范围内 aa 出现次数减去其他数个数的数量最大。

解法

考虑对于一个给定的 aa 求出 l,rl,r。这其实就是最大子段和,我们将所有位置上等于 aa 的数置为 11,其他置为 1-1 即可。

用一个 map 储存每个元素的下标,这样每个元素的位置只需要修改两次,总复杂度 O(nlogn)O(n\log n)

求出 l,rl,r 的问题也可以使用线段树维护,但是有一点麻烦。可以先求出正确的 aa,然后跑一次 dp 求出合法的 l,rl,r

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include<bits/stdc++.h>
using namespace std;

using ll=long long;
using ull=unsigned long long;
using pii=pair<int,int>;
#define all(x) x.begin(),x.end()
#define mem0(x) memset(x,0,sizeof(x))
#define YES puts("YES")
#define NO puts("NO")
#define Yes puts("Yes")
#define No puts("No")
#define errorf(...) fprintf(stderr, __VA_ARGS__)
#define endl '\n'

int read(){
int f=1,x=0;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){x=x*10+c-'0';c=getchar();}
return x*f;
}

const int N = (2e5+5)*4;
struct Node{
int s,ls,rs,ms;
};

struct SegementTree{
Node t[N];
void pushup(int o){
int lch=o<<1,rch=o<<1|1;
t[o].s=t[lch].s+t[rch].s;
t[o].ls=max(t[lch].ls,t[lch].s+t[rch].ls);
t[o].rs=max(t[rch].rs,t[rch].s+t[lch].rs);
t[o].ms=max({t[lch].ms,t[rch].ms,t[lch].rs+t[rch].ls});
}

void build(int o,int l,int r){
if(l==r){
t[o].s=t[o].ls=t[o].rs=t[o].ms=-1;
return;
}
int lch = o<<1, rch=o<<1|1;
int mid=(l+r)>>1;
build(lch,l,mid);
build(rch,mid+1,r);
pushup(o);
}

void modify(int o,int l,int r,int qi,int qk){
if(l==r){
t[o].s=t[o].ls=t[o].rs=t[o].ms=qk;
return;
}
int lch = o<<1, rch=o<<1|1;
int mid=(l+r)>>1;
if(qi<=mid)modify(lch,l,mid,qi,qk);
else modify(rch,mid+1,r,qi,qk);
pushup(o);
}
}st;

int a[N];
int dp[N],l[N];

void solve(int kase){
int n=read();
map<int, vector<int>> mp;
st.build(1,1,n);
for(int i=1;i<=n;i++){
a[i]=read();
mp[a[i]].push_back(i);
}
int ans = 0;
int x;

for(auto&[val,v]:mp){
for(int i:v)st.modify(1,1,n,i,1);
if(st.t[1].ms>ans){
ans=st.t[1].ms;
x=val;
}
for(int i:v)st.modify(1,1,n,i,-1);
}

for(int i=1;i<=n;i++){
if(a[i]==x)a[i]=1;
else a[i]=-1;
}

printf("%d ",x);
dp[1]=a[1];
l[1]=1;
for(int i=2;i<=n;i++){
dp[i]=max(dp[i-1]+a[i],a[i]);
if(dp[i]==dp[i-1]+a[i])l[i]=l[i-1];
else l[i]=i;
}

int ansr = max_element(dp+1,dp+1+n)-dp;
printf("%d %d\n",l[ansr],ansr);
}

int main(){
int T=read(),TT=1;
while(T--){
solve(TT++);
}
return 0;
}