最近公共祖先

最近公共祖先(LCA)指的是两个或多个点的公共祖先中离根最远的一个。

对于 LCA 问题,有多种算法可以求解。

倍增

过程

先将树进行一遍 DFS,预处理出 fau,i\textit{fa}_{u,i} 表示节点 uu2i2^i 级祖先。

接下来对于两个节点 u,vu,v,先根据两个节点的深度之差 depudepv|\textit{dep}_u - \textit{dep}_v|,将较深的节点用 fa\textit{fa} 数组,通过 O(logn)O(\log n) 次操作跳到一样深。

当两个节点一样深之后,若此时两个节点相同,直接返回。

否则需要找到最小的 dd,使得 uuvvdd 级祖先相同。为次,我们从最大的 ii 开始,依次递减到 00,每次如果 fau,ifav,i\textit{fa}_{u,i}\ne \textit{fa}_{v,i},就将 ufau,i,vfav,iu\gets \textit{fa}_{u,i}, v\gets \textit{fa}_{v,i},则最后 fau,0=fav,0=LCA(u,v)\textit{fa}_{u,0}=\textit{fa}_{v,0}=\operatorname{LCA}(u,v)

这个算法的正确性在于,我们相当于是寻找一个最大的 kk,使得 uuvvkk 级祖先不同。对于 kmaxk_{max} 的每一个二进制位,都是可以成功向上跳的,又我们向上跳的级数一定 kmax\le k_{max},所以循环结束之后 uuvv 都恰好被设置为其 kmaxk_{max} 级祖先,那么它们的父亲就恰好是两节点的 LCA。

复杂度

预处理 O(nlogn)O(n\log n),单次查询 O(logn)O(\log n)

代码

P3379 【模板】最近公共祖先(LCA)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;
int fa[N][20],dep[N];
vector<int> G[N];
void dfs(int u,int f){
dep[u]=dep[f]+1;
fa[u][0]=f;
// 递推求出fa[u][i]
for(int i=1;i<20;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int v:G[u]){
if(v==f)continue;
dfs(v,u);
}
}

int lca(int u,int v){
if(dep[u]<dep[v])swap(u,v);
// 将两个节点高度设置为相同
while(dep[u]!=dep[v])u=fa[u][__lg(dep[u]-dep[v])];
if(u==v)return u;
// 向上跳尽量多,使得两个节点依然不同
for(int i=19;i>=0;i--){
if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
}
// 此时 u,v 的父亲就是 LCA
return fa[u][0];
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m,rt; cin>>n>>m>>rt;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);G[v].push_back(u);
}

dfs(rt,rt);

while(m--){
int u,v;cin>>u>>v;
cout<<lca(u,v)<<'\n';
}

return 0;
}

树链剖分

过程

详见 树链剖分

对于两个节点 u,vu,v,将所在重链顶部节点深度较深的往上跳,直至跳到同一条重链上,此时深度较浅的节点就是 LCA。

复杂度

预处理 O(n)O(n),查询 O(logn)O(\log n)

代码

P3379 【模板】最近公共祖先(LCA)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;
int fa[N],sz[N],dep[N],hson[N],top[N];
vector<int> G[N];
void dfs1(int u,int f){
fa[u]=f,sz[u]=1,dep[u]=dep[f]+1;
for(int v:G[u]){
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[hson[u]])hson[u]=v;
}
}
void dfs2(int u,int t){
top[u]=t;
if(hson[u])dfs2(hson[u],t);
for(int v:G[u]){
if(v==fa[u] || v==hson[u])continue;
dfs2(v,v);
}
}
int lca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
u=fa[top[u]];
}
if(dep[u]>dep[v])return v;
return u;
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m,rt; cin>>n>>m>>rt;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);G[v].push_back(u);
}

dfs1(rt,rt);
dfs2(rt,rt);

while(m--){
int u,v;cin>>u>>v;
cout<<lca(u,v)<<'\n';
}

return 0;
}

欧拉序 + ST表

过程

先求出遍历的欧拉序,记录每个节点第一次出现的位置。

欧拉序:
进行 DFS 过程时,每次进入一个节点或回溯回一个节点就加入一个数组末尾,最后得到的就是欧拉序。

显然,节点数为 nn 的树欧拉序长度为 2n12n-1

则两个节点的 LCA 就是欧拉序中两个节点第一次出现位置之间这一段区间中深度最小的节点。

这个算法的正确性在于,从一个节点到另一个节点,一定会经过它们的 LCA,而且不会经过 LCA 的父亲,所以直接求出深度最小的节点就可以了。

用 ST 表维护区间深度最小的节点。

复杂度

预处理 O(nlogn)O(n\log n),查询 O(1)O(1)

有两倍常数。

代码

P3379 【模板】最近公共祖先(LCA)

实现时注意数组要开两倍大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;

int dep[N];
int e[N*2],pos[N],tot,st[N*2][20];
vector<int> G[N];

void dfs(int u,int fa){
e[++tot]=u; pos[u]=tot;
dep[u]=dep[fa]+1;
for(int v:G[u]){
if(v==fa)continue;
dfs(v,u);
e[++tot]=u;
}
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m,rt; cin>>n>>m>>rt;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);G[v].push_back(u);
}

dfs(rt,rt);

for(int i=1;i<=2*n-1;i++)st[i][0]=e[i];


for(int j=1;j<20;j++){
for(int i=1;(i+(1<<j))-1<=2*n-1;i++){
st[i][j] = (dep[st[i][j-1]] < dep[st[i+(1<<(j-1))][j-1]])
? st[i][j-1]
: st[i+(1<<(j-1))][j-1];
}
}

while(m--){
int u,v; cin>>u>>v;
int l=pos[u],r=pos[v];
if(l>r)swap(l,r);
int o = __lg(r-l+1);
if(dep[st[l][o]]<dep[st[r-(1<<o)+1][o]])cout<<st[l][o]<<'\n';
else cout<<st[r-(1<<o)+1][o]<<'\n';
}

return 0;
}

DFS 序 + ST表

过程

先求出遍历的 DFS 序和时间戳 dfn\textit{dfn}

对于求出 u,vu,v 的 LCA 来说(不妨设 dfnudfnv\textit{dfn}_u\le \textit{dfn}_v),u=vu=v 显然,否则在 DFS 序中 [dfnu+1,dfnv][\textit{dfn}_u+1,\textit{dfn}_v] 区间最浅的节点的父亲就是 LCA。

下面证明这个算法是正确的,设 LCA(u,v)=d\operatorname{LCA}(u,v) = d

可以发现 DFS 序的 [dfnu+1,dfnv][\textit{dfn}_u+1,\textit{dfn}_v] 区间中节点一定全部在 dd 的子树中(不含 dd),又考察 dvd\to v 路径上从 dd 出发下一个节点 vv' 一定在 [dfnu+1,dfnv][\textit{dfn}_u+1,\textit{dfn}_v] 区间中,所以深度最浅的节点一定是 dd 的儿子,符合题意。

复杂度

预处理 O(nlogn)O(n\log n),查询 O(1)O(1)

代码

P3379 【模板】最近公共祖先(LCA)

实现方面可以在 ST 表初始化的时候就储存父亲而不是该节点,这样就无需用数组记录父亲了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;

int dep[N];
int dfn[N], tot, st[N][20];
vector<int> G[N];

void dfs(int u,int fa){
st[++tot][0]=fa; // 直接记录父亲
dfn[u]=tot;
dep[u]=dep[fa]+1;
for(int v:G[u]){
if(v==fa)continue;
dfs(v,u);
}
}

int get(int u,int v){
return dep[u]<dep[v]?u:v;
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m,rt; cin>>n>>m>>rt;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);G[v].push_back(u);
}

dfs(rt,rt);

for(int j=1;j<20;j++){
for(int i=1;(i+(1<<j))-1<=n;i++){
st[i][j] = get(st[i][j-1],st[i+(1<<(j-1))][j-1]);
}
}

while(m--){
int u,v; cin>>u>>v;
if(u==v){
cout<<u<<'\n';continue;
}

int l=min(dfn[u],dfn[v])+1,r=max(dfn[u],dfn[v]);
int o = __lg(r-l+1);
cout<< get(st[l][o],st[r-(1<<o)+1][o]) <<'\n';
}

return 0;
}

Tarjan

过程

Tarjan 是一种离线,利用 DFS 和并查集来求解 LCA 的算法。

具体过程如下:

  • 初始化 nn 个集合。
  • DFS 访问到节点 uu 时,
    1. 对于 uu 的每个儿子 vv,对 vv DFS,并在回溯后将 vv 所在集合代表元素设为 uu
    2. 对于每个含有 uu 的查询,若另一个节点 vv 已经被访问过,则 vv 所在集合的代表元素就是 uuvv 的 LCA。

这个算法的正确性在于,我们先访问完 uu 后,会回溯到它们的 LCA 再往下到 vv,所以此时 uu 所在集合的代表元素就恰好是它们的 LCA。

复杂度

离线算法。

复杂度 O(n+mα(n))O(n+m\alpha(n)),其中 nn 为节点数,mm 为询问数。

代码

P3379 【模板】最近公共祖先(LCA)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;


vector<int> G[N];
int fa[N];
int ans[N];
int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}
void unite(int i,int j){fa[find(i)]=find(j);}


vector<pair<int,int>> queries[N];
bool vis[N];

void dfs(int u,int f){
vis[u]=1;
for(int v:G[u]){
if(v==f)continue;
dfs(v,u);
unite(v,u);
}
for(auto [v,id]:queries[u]){
if(vis[v])ans[id]=find(v);
}
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,m,rt; cin>>n>>m>>rt;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);G[v].push_back(u);
}

for(int i=1;i<=n;i++)fa[i]=i;

for(int i=1;i<=m;i++){
int u,v; cin>>u>>v;
queries[u].push_back({v,i});
queries[v].push_back({u,i});
}

dfs(rt,rt);

for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
return 0;
}

总结

算法 倍增 树链剖分 欧拉序 + ST表 DFS 序 + ST表 Tarjan
复杂度 预处理 O(nlogn)O(n\log n),查询 O(logn)O(\log n) 预处理 O(n)O(n),查询 O(logn)O(\log n) 预处理 O(nlogn)O(n\log n),查询 O(1)O(1) 预处理 O(nlogn)O(n\log n),查询 O(1)O(1) O(n+mα(n))O(n+m\alpha(n))
备注 常数较小 有两倍常数,有 O(n)O(1)O(n)\sim O(1)实现 O(n)O(1)O(n)\sim O(1)实现 离线算法