本篇文章中,存树的方式统一采用链式前向星。
链式前向星实现

1
2
3
4
5
6
7
8
9
10
int head[N];
int nxt[N<<1],to[N<<1];
int cnt=1;
void add_edge(int u,int v){
nxt[cnt] = head[u];
to[cnt]=v;
head[u]=cnt;
cnt++;
}

树链剖分的概念

将一棵树分割为若干条链,将维护路径问题转化为维护若干条链上信息。

我们介绍一种分割方式——重链剖分

重链剖分可以将树上任意一条路径分割为 O(logn)O(\log n) 条连续的链,每条链上 dfs 序连续,因此可以采用线段树等维护路径上的信息。

对于重链剖分,给出一些定义:

  • 重子节点:子结点中子树最大的节点。若有多个最大的,任意其一为重子节点。
  • 轻子节点:剩余的子节点
  • 重边:到重子节点的边
  • 重链:若干条重边首尾相连形成重链。特别的,对于单个子节点,也认为是一条重链。

这样即可将整棵树划分为若干条重链。

重链剖分示意图

实现

重链剖分通过两次 dfs 实现。

第一次 dfs 求出子树大小,父节点,重子节点,深度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int fa[N],sz[N],hson[N],dep[N];
void dfs1(int u){
hson[u]=-1;
sz[u]=1;
for(int e=head[u];e;e=nxt[e]){
int v=to[e];
if(v==fa[u])continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
if(hson[u]==-1 || sz[hson[u]]<sz[v]){
hson[u]=v;
}
sz[u]+=sz[v];
}
}

第二次 dfs 求出 dfs 序和每个节点所在重链的深度最浅节点。在第二次dfs中,优先遍历重子节点,即可保证重链上的点 dfs 序连续。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int dfn[N],top[N];int cnt2=0;
void dfs2(int u,int t){
top[u]=t;
dfn[u]=cnt2++;

if(hson[u]==-1)return;
dfs2(hson[u],t); // 优先遍历重子节点

for(int e=head[u];e;e=nxt[e]){
int v=to[e];
if(v==fa[u] || v==hson[u])continue;
dfs2(v,v);
}
}

应用

维护路径

选择深度较大的子节点,往上跳,直到在同一条重链上。

1
2
3
4
5
6
7
8
9
void operate_path(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x,y);
OPERATE(dfn[top[x]],dfn[x]); // 对这一区间执行任意操作
x = fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
OPERATE(dfn[x],dfn[y]); // 跳到同一条重链上了
}

维护子树

显然子树内 dfs 序连续,只需对 根节点根节点+子树大小-1 这一范围操作即可。

1
2
3
void operate_tree(int x){
OPERATE(dfn[x],dfn[x]+sz[x]-1);
}

求 lca

与维护路径操作类似,也是深度较大的节点先往上跳,直到跳到同一条重链上,此时深度较小的节点即为 lca。

1
2
3
4
5
6
7
8
int lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x,y);
x = fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
return x;
}

例题

洛谷P3384 【模板】重链剖分/树链剖分

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include<bits/stdc++.h>
using namespace std;
#define ll long long

const int N = 1e5+5;

int n,m,r;
ll p;

int h[N],nxt[N*2],to[N*2];
int cnt = 1;

void add_edge(int u,int v){
nxt[cnt]=h[u];
to[cnt]=v;
h[u]=cnt++;
}

ll val[N];

int fa[N],dep[N],sz[N],hson[N];
void dfs1(int u){
hson[u]=-1;
sz[u]=1;
for(int e=h[u];e;e=nxt[e]){
int v = to[e];
if(v==fa[u])continue;
dep[v]=dep[u]+1;
fa[v]=u;
dfs1(v);
sz[u]+=sz[v];
if(hson[u]==-1 || sz[hson[u]]<sz[v]){
hson[u]=v;
}
}
}

int top[N],dfn[N],rnk[N],cnt2=1;
void dfs2(int u,int t){
top[u]=t;
rnk[cnt2]=u;
dfn[u]=cnt2++;
if(hson[u]==-1)return;
dfs2(hson[u],t);
for(int e=h[u];e;e=nxt[e]){
int v = to[e];
if(v==fa[u] || v==hson[u])continue;
dfs2(v,v);
}
}

int ql,qr;ll qk;
ll t[N<<2],lz[N<<2];

void build(int o,int l,int r){
if(l==r){
t[o] = val[rnk[l]];
return;
}
int ls = (o<<1),rs=((o<<1)|1);
int mid = (l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
t[o] = (t[ls]+t[rs])%p;
}
void pushdown(int o,int l,int r){
int ls = (o<<1),rs=((o<<1)|1);
int mid = (l+r)>>1;
t[ls] = (t[ls]+lz[o]*(mid-l+1))%p;
t[rs] = (t[rs]+lz[o]*(r-mid))%p;
lz[ls]=(lz[ls]+lz[o])%p;
lz[rs]=(lz[rs]+lz[o])%p;
lz[o]=0;
}
void add(int o,int l,int r){
if(ql<=l && r<=qr){
t[o] = (t[o]+qk*(r-l+1))%p;
lz[o] = (lz[o]+qk)%p;
return;
}
pushdown(o,l,r);
int ls = (o<<1),rs=((o<<1)|1);
int mid = (l+r)>>1;
if(mid>=ql)add(ls,l,mid);
if(mid<qr)add(rs,mid+1,r);
t[o] = (t[ls]+t[rs])%p;
}
ll query(int o,int l,int r){
if(ql<=l && r<=qr){
return t[o];
}
pushdown(o,l,r);
int ls = (o<<1),rs=((o<<1)|1);
int mid = (l+r)>>1;
ll ans = 0;
if(mid>=ql)ans += query(ls,l,mid);
if(mid<qr)ans += query(rs,mid+1,r);
return ans %p;
}

void add_chain(int x,int y,ll z){
while(top[x]!=top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x,y);
ql = dfn[top[x]],qr=dfn[x],qk=z;
add(1,1,n);
x = fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ql = dfn[x], qr = dfn[y] ,qk=z;
add(1,1,n);
}

ll query_chain(int x,int y){
ll ans = 0;
while(top[x]!=top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x,y);
ql = dfn[top[x]],qr=dfn[x];
ans += query(1,1,n);
x = fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ql = dfn[x], qr = dfn[y];
ans += query(1,1,n);
return ans % p;
}

void add_tree(int x,ll z){
ql = dfn[x],qr=dfn[x]+sz[x]-1,qk=z;
add(1,1,n);
}

ll query_tree(int x){
ql = dfn[x],qr=dfn[x]+sz[x]-1;
return query(1,1,n);
}

int main(){
scanf("%d%d%d%lld",&n,&m,&r,&p);
for(int i=1;i<=n;i++){
scanf("%lld",val+i);
}
for(int i=1;i<n;i++){
int u1,v1;
scanf("%d%d",&u1,&v1);
add_edge(u1,v1);
add_edge(v1,u1);
}

dfs1(r);
dfs2(r,r);

build(1,1,n);

while(m--){
int o;scanf("%d",&o);
if(o==1){
int x,y;ll z;
scanf("%d%d%lld",&x,&y,&z);
add_chain(x,y,z);
}else if(o==2){
int x,y;
scanf("%d%d",&x,&y);
printf("%lld\n",query_chain(x,y));
}else if(o==3){
int x;ll z;
scanf("%d%lld",&x,&z);
add_tree(x,z);
}else{
int x;
scanf("%d",&x);
printf("%lld\n",query_tree(x));
}
}
return 0;
}

洛谷P2146 [NOI2015] 软件包管理器

对于安装操作,即为该节点到根节点置为 1。对于卸载操作,只需将子树置为 0。线段树维护修改即可。

对于改变数量,可以用操作前后线段树根节点值的变化表示。

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include<bits/stdc++.h>
using namespace std;

const int N = 1e5+5;

int head[N];
int nxt[N<<1],to[N<<1];
int cnt=1;
void add_edge(int u,int v){
nxt[cnt] = head[u];
to[cnt]=v;
head[u]=cnt;
cnt++;
}

int fa[N],sz[N],hson[N],dep[N];
void dfs1(int u){
hson[u]=-1;
sz[u]=1;
for(int e=head[u];e;e=nxt[e]){
int v=to[e];
if(v==fa[u])continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
if(hson[u]==-1 || sz[hson[u]]<sz[v]){
hson[u]=v;
}
sz[u]+=sz[v];
}
}

int dfn[N],top[N];int cnt2=0;
void dfs2(int u,int t){
top[u]=t;
dfn[u]=cnt2++;

if(hson[u]==-1)return;
dfs2(hson[u],t);

for(int e=head[u];e;e=nxt[e]){
int v=to[e];
if(v==fa[u] || v==hson[u])continue;
dfs2(v,v);
}
}

int t[N<<2],tag[N<<2];
void pushdown(int o,int l,int r){
if(tag[o]!=-1){
int ls = (o<<1), rs = ((o<<1)|1);
int mid = (l+r)>>1;
t[ls]=tag[o]*(mid-l+1);
t[rs]=tag[o]*(r-mid);
tag[ls]=tag[o];
tag[rs]=tag[o];
tag[o]=-1;
}
}
void modify(int o,int l,int r,int ql,int qr,int qk){
// printf("%d %d %d %d %d %d\n",o,l,r,ql,qr,qk);
if(ql<=l && r<=qr){
t[o] = qk * (r-l+1);
tag[o] = qk;
return;
}
pushdown(o,l,r);
int ls = (o<<1), rs = ((o<<1)|1);
int mid = (l+r)>>1;
if(mid>=ql)modify(ls,l,mid,ql,qr,qk);
if(mid<qr)modify(rs,mid+1,r,ql,qr,qk);
t[o] = t[ls]+t[rs];
}

int query(int o,int l,int r,int ql,int qr){
if(ql<=l && r<=qr){
return t[o];
}
pushdown(o,l,r);
int ls = (o<<1), rs = ((o<<1)|1);
int mid = (l+r)>>1;
int ans = 0;
if(mid>=ql)ans+=query(ls,l,mid,ql,qr);
if(mid<qr)ans+=query(rs,mid+1,r,ql,qr);
return ans;
}

int n;

void install(int x){
// 往上跳
int ans = -t[1];
while(top[x]!=0){
// printf("install %d-> %d (%d,%d)\n",x,top[x],dfn[x],dfn[top[x]]);
modify(1,0,n-1,dfn[top[x]],dfn[x],1);
x=fa[top[x]];
}
// printf("install2 %d-> %d (%d,%d)\n",x,top[x],dfn[x],dfn[top[x]]);
modify(1,0,n-1,dfn[top[x]],dfn[x],1);
ans+=t[1];
printf("%d\n",ans);
}

void uninstall(int x){
printf("%d\n",query(1,0,n-1,dfn[x],dfn[x]+sz[x]-1));
modify(1,0,n-1,dfn[x],dfn[x]+sz[x]-1,0);
}

int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int j;scanf("%d",&j);
add_edge(i,j);
add_edge(j,i);
}

dfs1(0);
dfs2(0,0);

memset(tag,-1,sizeof(tag));

int m;scanf("%d",&m);
char op[20];
while(m--){
int x;
scanf("%s %d",op, &x);
if(op[0]=='i'){
install(x);
}
else{
uninstall(x);
}
}
return 0;
}