AC 自动机

定义与构建

AC 自动机是解决多模式匹配的一类自动机。

假设现在有若干个字符串 s1,s2,,sks_1,s_2,\ldots,s_k,构建出它们的 Trie。AC 自动机关键在于求出 fail 指针,定义为最长的在 Trie 中出现过的公共前后缀。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// 记得修改字符集大小和字符种类
const int SIGMA = 26;
const char BASE = 'a';
int ch[N][SIGMA],fail[N],cnt[N],tot=1;
void ins(const string& s){
int u=1;
for(int i=0;i<s.size();i++){
int c=s[i]-BASE;
if(!ch[u][c])ch[u][c]=++tot;
u=ch[u][c];
}
cnt[u]++;
}
void get_fail(){
for(int i=0;i<SIGMA;i++)ch[0][i]=1;
fail[1]=0;
queue<int> q;q.push(1);
while(q.size()){
int u=q.front();q.pop();
for(int c=0;c<SIGMA;c++){
int v=ch[u][c];
if(!v)ch[u][c]=ch[fail[u]][c];
else{
fail[v]=ch[fail[u]][c];
q.push(v);
}
}
}
}

应用

字符串匹配

可以发现 fail 指针构成了一棵树,称为 fail 树。

用文本串在 AC 自动机上匹配,换句话说就是寻找一些文本串的子串。

众所周知,子串等于前缀的后缀,于是我们对于文本串每次加入一个字符在 AC 自动机上转移,此时得到的就是文本串的前缀的最长的后缀,使得其为 Trie 树上的某个节点。然后还可能有效的后缀就是这个节点 fail 树上所有的祖先,也就是到根的链。

可以发现,这样的若干条链加起来就是文本串中所有在 Trie 上部分的子串。也就是说,我们求出了文本串所有有效的子串部分。

例题:P5357 【模板】AC 自动机

给你一个文本串 SSnn 个模式串 T1nT_{1∼n},请你分别求出每个模式串 TiT_iSS 中出现的次数。

做法与参考代码

对模式串建出 AC 自动机。

然后我们依次插入文本串的每个字符(也就是对于文本串的每个前缀),现在我们要找这个前缀所对应的后缀,有效的就是 fail 树上到根的部分。

当然不能暴力跳,离线处理即可线性完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<bits/stdc++.h>
using namespace std;

const int N = 2e5+5;
int ch[N][26],fail[N],ans[N],in[N],tot=1;
int pos[N];

int ins(const string& s){
int u=1;
for(int i=0;i<s.size();i++){
int c=s[i]-'a';
if(!ch[u][c])ch[u][c]=++tot;
u=ch[u][c];
}
return u;
}
void get_fail(){
queue<int> q;
for(int i=0;i<26;i++)ch[0][i]=1;
fail[1]=0;q.push(1);
while(q.size()){
int u=q.front();q.pop();
for(int c=0;c<26;c++){
int& v=ch[u][c];
if(!v)v=ch[fail[u]][c];
else{
fail[v]=ch[fail[u]][c];
in[fail[v]]++;
q.push(ch[u][c]);
}
}
}
}
void query(const string& s){
int u=1;
for(int i=0;i<s.size();i++){
int c=s[i]-'a';
u=ch[u][c];
ans[u]++;
}
}
void topo(){
queue<int> q;
for(int i=1;i<=tot;i++){
if(!in[i])q.push(i);
}

while(q.size()){
int u=q.front();q.pop();
ans[fail[u]]+=ans[u];
in[fail[u]]--;
if(!in[fail[u]])q.push(fail[u]);
}
}


int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n;cin>>n;
for(int i=1;i<=n;i++){
string s;cin>>s;
pos[i] = ins(s);
}

get_fail();

string s; cin>>s;
query(s);
topo();
for(int i=1;i<=n;i++)cout<<ans[pos[i]]<<'\n';
return 0;
}

构建一个满足条件的字符串

有些问题是这样的:找到一个字符串,使得至少包含一个子串/一个子串都不包含。

可以发现,任何一个字符串都等价于在 AC 自动机(是一张有向图)上的一条路径。

根据上文所说的,这条路径上的每个点所能匹配到的子串就是它和它在 fail 树上的所有祖先。

所以可以先做一个类似前缀和的事情,每个点维护到根的信息并,然后就能快速知道这个点拥有的信息是怎么样的,之后就可以应用各类图论算法,从而方便的解决题目。

例题:P2444 [POI 2000] 病毒

给定若干个 01 串,问是否存在一个无限长的 01 串,使得其不包含任意一个给定的串作为子串。

做法与参考代码

先给这些 01 串建立出 01 自动机,然后把每个终点标记为危险的。还需要标记危险的节点还有它们在 fail 树上的后代,那么在求出 fail 指针的时候这是可以顺便完成的。

那么问题就转化成,给定一张有向图和起点,问是否存在一条无限长的路径不经过任何危险节点。把危险节点和它们连的边都删了,这个等价于问从这个起点走能不能碰到环,dfs 即可。(注意,dfs 判环不是遇到访问过的节点就是有环哦,因为是有向图!应该是碰到访问中的节点)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include<bits/stdc++.h>
using namespace std;

const int N = 3e4+5;


int ch[N][2],fail[N],cnt[N],tot=1;
void ins(const string& s){
int u=1;
for(int i=0;i<s.size();i++){
int c=s[i]-'0';
if(!ch[u][c])ch[u][c]=++tot;
u=ch[u][c];
}
cnt[u]++;
}
void get_fail(){
for(int i=0;i<2;i++)ch[0][i]=1;
fail[1]=0;
queue<int> q;q.push(1);
while(q.size()){
int u=q.front();q.pop();
cnt[u]+=cnt[fail[u]]; // 在这里传递信息
for(int c=0;c<2;c++){
int v=ch[u][c];
if(!v)ch[u][c]=ch[fail[u]][c];
else{
fail[v]=ch[fail[u]][c];
q.push(v);
}
}
}
}

int vis[N];
void dfs(int u){
vis[u] = -1;
for(int i=0;i<2;i++){
int v = ch[u][i];
if(vis[v]==-1){
cout << "TAK\n";
exit(0);
}
if(!cnt[v] && !vis[v])dfs(v);
}
vis[u] = 1;
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);
cout<<fixed;

int n;cin>>n;
for(int i=1;i<=n;i++){
string s;cin>>s;
ins(s);
}
get_fail();

dfs(1);
cout << "NIE\n";
return 0;
}

练习

1. CF1482H. Exam
给定 nn 个不同的字符串 S1,S2,,SnS_1, S_2, \dots, S_n,求数对 (i,j)(i,j) 的个数,满足 SiS_iSjS_j 的子串,且不存在一个不等于 iijjkk,满足 SiS_iSkS_k 的子串且 SkS_kSjS_j 的子串。

数据范围:n,S106n, \sum |S| \leq 10^6

做法与参考代码

考虑先对所有串建出 AC 自动机,考虑枚举那个比较长的串,计算哪些比较短的串成为它的贡献。

考虑长串的每个前缀,可以成为答案的只有这个前缀的一个最长的后缀,满足其恰好是某个串。

这样的串不是很多(总共是 O(S)O(\sum|S|) 级别的),我们考虑如何得到一个充要条件判断这些串是否有其他串覆盖它。

首先,用树状数组可以快速计算出每个串出现的次数。如果这个串没有其他串覆盖它,那么这个串每次出现在长串的时候一定都是这个前缀最长的那个后缀,而且不存在其他位置的最长后缀把它覆盖了。

所以我们把每个位置的串区间找出来,如果一个区间被其他区间覆盖了就跳过它。如果一个串的出现次数恰等于它的没被覆盖区间个数,这个串就合法。

所以就做完了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<bits/stdc++.h>
using namespace std;

#define int long long

const int N = 1e6+5;
int ch[N][26], fail[N], cnt[N], tot = 1;
int lowest[N], realcnt[N];

int dep[N];

void ins(const string& s){
int u = 1;
for(int i=0;i<s.size();i++){
int c = s[i]-'a';
if(!ch[u][c])ch[u][c]=++tot,dep[tot]=dep[u]+1;
u=ch[u][c];
}
cnt[u]++;
}
vector<int> G[N];
void get_fail(){
for(int i=0;i<26;i++)ch[0][i]=1;
fail[1]=0;
queue<int> q;q.push(1);
while(q.size()){
int u = q.front();q.pop();
G[fail[u]].push_back(u);
for(int c=0;c<26;c++){
int v = ch[u][c];
if(!v)ch[u][c]=ch[fail[u]][c];
else{
fail[v]=ch[fail[u]][c];
q.push(v);
}
}
}
}

int dfn[N],sz[N],dd;
void dfs(int u){
dfn[u]=++dd;
sz[u]=1;
for(int v:G[u]){
lowest[v] = cnt[v] ? v : lowest[u];
dfs(v);
sz[u]+=sz[v];
}
}

struct BIT{
int t[N];
int lowbit(int x){return x&-x;}
void add(int i,int k){
for(;i<N;i+=lowbit(i))t[i]+=k;
}
int query(int i){
int r=0;
for(;i;i-=lowbit(i))r+=t[i];
return r;
}
int query(int l,int r){return query(r)-query(l-1);}
}t;

signed main(){
int n; cin>>n;
vector<string> s(n);
for(int i=0;i<n;i++){
cin>>s[i];
ins(s[i]);
}

get_fail();
dfs(1);

int ans = 0;
for(int i=0;i<n;i++){
int now = 1;
vector<tuple<int,int,int>> p;
for(int j=0;j<s[i].size();j++){
int x = s[i][j] - 'a';
now = ch[now][x];

t.add(dfn[now], 1);

int y = (j+1<s[i].size())?lowest[now]:lowest[fail[now]];
if(y)p.push_back({j-dep[y]+1, j,y});
}

ranges::sort(p, std::less<>{}, [](const auto& x){return pair{get<0>(x), -get<1>(x)};});

int maxr = -1;
for(auto [l,r,id]:p){
if(r>maxr){
maxr = r;
realcnt[id]++;
if(t.query(dfn[id], dfn[id]+sz[id]-1) == realcnt[id]){
ans ++;
}
}
}

now = 1;
for(int j=0;j<s[i].size();j++){
int x = s[i][j] - 'a';
now = ch[now][x];

t.add(dfn[now], -1);

int y = (j+1<s[i].size())?lowest[now]:lowest[fail[now]];
realcnt[y]=0;
}
}

cout << ans << '\n';
return 0;
}

后缀数组 (SA)

定义与构建

后缀数组是能求出有关 ss 所有后缀某些信息的一个数据结构,具体来说,可以求出:

  1. sa[i]sa[i]:在 ss 的所有后缀中,排名为 ii 的后缀的开始位置。
  2. rk[i]rk[i]ss 的后缀 s[i:]s[i:] 在所有后缀中的排名。
  3. h[i]h[i]ss 的后缀 s[sa[i]:],s[sa[i+1]:]s[sa[i]:], s[sa[i+1]:] 的最长公共前缀。

构建代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 0-下标,h[i]表示后缀sa[i]和sa[i+1]的LCP长度
struct SA {
vector<int> sa, rk, h;
SA(const string &s) {
int n = s.length();
sa.resize(n); rk.resize(n); h.resize(n - 1);
iota(sa.begin(), sa.end(), 0);
sort(sa.begin(), sa.end(), [&](int a, int b){return s[a]<s[b];});
rk[sa[0]] = 0;
for(int i=1;i<n;++i)
rk[sa[i]]=rk[sa[i-1]]+(s[sa[i]]!=s[sa[i-1]]);
int k = 1;
vector<int> t, c(n); t.reserve(n);
while(rk[sa[n - 1]] < n - 1) {
t.clear();
for(int i=0;i<k;i++) t.push_back(n-k+i);
for(auto i:sa)if(i>=k) t.push_back(i-k);
fill(c.begin(), c.end(), 0);
for(int i=0;i<n;i++) c[rk[i]]++;
for(int i=1;i<n;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[rk[t[i]]]]=t[i];
swap(rk, t);
rk[sa[0]] = 0;
for (int i=1;i<n;i++){
rk[sa[i]]=rk[sa[i-1]];
if(t[sa[i-1]]<t[sa[i]] || sa[i-1]+k==n || t[sa[i-1]+k]<t[sa[i]+k])
rk[sa[i]]++;
}
k *= 2;
}
for(int i=0,j=0;i<n;i++){
if(rk[i]==0) j=0;
else{
if(j) j--;
while(i+j<n && sa[rk[i]-1]+j<n && s[i+j]==s[sa[rk[i]-1]+j])++j;
h[rk[i]-1] = j;
}
}
}};

应用

求本质不同子串个数

我们还是利用每个子串是某个后缀前缀的性质。

所以可以枚举某个后缀,在减去重复的前缀部分。

考虑按字典序加入每个后缀,每次新增加的前缀正好就是和上一个的 LCP 剩下的前缀。

所以答案可以直接写为 n(n+1)/2h[i]n(n+1)/2 - \sum h[i]

求子串 LCP & 比较子串字典序

先考虑如何求后缀的 LCP。
后缀 s[x:]s[x:]s[y:]s[y:] (xy,rk[x]<rk[y]x\ne y, rk[x]<rk[y]) 的 LCP 长度恰为 minrk[x]i<rk[y]hi\displaystyle \min_{rk[x]\le i < rk[y]}h_i

这个最小值用一个 ST 表维护即可。

那么子串的 LCP 就是求出开始位置后缀的 LCP,再和各自长度取最小值。(注意特判起点相同的情况)

那么知道子串 LCP 之后,比较子串字典序是显然的。(考虑下一个字符是什么就行了,注意特判没有下一个字符的情况)

练习

1. HackerRank - special-substrings

给定一个字符串 ss,要求每次插入一个字符,求出当前字符串所有本质不同回文子串的前缀的个数。

s3×105|s|\le 3\times 10^5

做法与参考代码

易知一个字符串本质不同的回文子串是 O(n)O(n) 级别的,所以可以用 PAM 先求出所有这些回文子串。

至于本质不同的回文前缀个数,仿照 SA 求本质不同子串的方法,把这些回文子串按字典序排序,然后减去相邻的 LCP 长度即可。

那么,用一个 set 去维护所有回文子串(只存端点),然后每次更新前驱后继的贡献即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include<bits/stdc++.h>
using namespace std;

using ll = long long;

const int N = 3e5+5;

struct SA {
vector<int> sa, rk, h;
vector<array<int, 20>> st;
void init(const string &s) {
int n = s.length();
sa.resize(n); rk.resize(n); h.resize(n - 1);
st.resize(n-1);
iota(sa.begin(), sa.end(), 0);
sort(sa.begin(), sa.end(), [&](int a, int b){return s[a]<s[b];});
rk[sa[0]] = 0;
for(int i=1;i<n;++i)
rk[sa[i]]=rk[sa[i-1]]+(s[sa[i]]!=s[sa[i-1]]);
int k = 1;
vector<int> t, c(n); t.reserve(n);
while(rk[sa[n - 1]] < n - 1) {
t.clear();
for(int i=0;i<k;i++) t.push_back(n-k+i);
for(auto i:sa)if(i>=k) t.push_back(i-k);
fill(c.begin(), c.end(), 0);
for(int i=0;i<n;i++) c[rk[i]]++;
for(int i=1;i<n;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[rk[t[i]]]]=t[i];
swap(rk, t);
rk[sa[0]] = 0;
for (int i=1;i<n;i++){
rk[sa[i]]=rk[sa[i-1]];
if(t[sa[i-1]]<t[sa[i]] || sa[i-1]+k==n || t[sa[i-1]+k]<t[sa[i]+k])
rk[sa[i]]++;
}
k *= 2;
}
for(int i=0,j=0;i<n;i++){
if(rk[i]==0) j=0;
else{
if(j) j--;
while(i+j<n && sa[rk[i]-1]+j<n && s[i+j]==s[sa[rk[i]-1]+j])++j;
h[rk[i]-1] = j;
}
}

for(int i=0;i<n-1;i++){
st[i][0] = h[i];
}

for(int j=1;j<20;j++){
for(int i=0;i+(1<<j)-1 < n-1; i++){
st[i][j] = min(st[i][j-1], st[i+(1<<(j-1))][j-1]);
}
}
}

int lcp(int l1,int r1,int l2,int r2){
if(l1==l2)return min(r1-l1+1,r2-l2+1);

int x = rk[l1];
int y = rk[l2];

if(x>y)swap(x,y);
--y;
int o = __lg(y-x+1);
return min(min(st[x][o], st[y-(1<<o)+1][o]), min(r1-l1+1, r2-l2+1));
}
}sa;

string s;

struct Comp{
bool operator()(const pair<int,int>& lhs,const pair<int,int>& rhs)const{
int x = sa.lcp(lhs.first, lhs.second, rhs.first, rhs.second);
if(x == rhs.second-rhs.first+1) return 0;
if(x == lhs.second-lhs.first+1)return 1;
return s[lhs.first+x] < s[rhs.first+x];
}
};

set<pair<int,int>, Comp> st;

ll ans = 0;

auto calc(const auto& x1,const auto& x2){
return sa.lcp(x1.first, x1.second, x2.first, x2.second);
}

void ins(int l,int r){
pair<int,int> x = {l,r};
auto it = st.lower_bound(x);
if(it != st.end() && it != st.begin()){
ans += calc(*prev(it), *it);
}

ans += r-l+1;

it = st.insert(x).first;

if(next(it) != st.end()) ans -= calc(*it, *next(it));
if(it != st.begin()) ans -= calc(*it, *prev(it));
}

struct PAM{
int fail[N], ch[N][26], len[N], s[N], tot, cnt, lst;
PAM(){
len[0]=0,len[1]=-1,fail[0]=1;
tot=lst=0,cnt=1,s[0]=-1;
}
int get_fail(int x){
while(s[tot-1-len[x]]!=s[tot])x=fail[x];
return x;
}
void insert(char c){
s[++tot]=c-'a';
int p = get_fail(lst);
if(!ch[p][s[tot]]){
len[++cnt]=len[p]+2;
int t = get_fail(fail[p]);
fail[cnt] = ch[t][s[tot]];
ch[p][s[tot]] = cnt;
ins(tot-len[cnt],tot-1);
}
lst = ch[p][s[tot]];
}
}pam;

int main(){
int n; cin>>n;
cin>>s;
sa.init(s);
for(int i=0;i<n;i++){
pam.insert(s[i]);
cout << ans << '\n';
}
return 0;
}

2. QOJ11523 Rikka with New Year’s Party
给定一个字符串 ss,求 ss 「独特」的子串个数。

称两个字符串 s,ts,t 是「独特」的,当且仅当不存在一个双射 f:ΣΣf:\Sigma \to \Sigma,使得 ss 每个字符经过 ff 映射后得到的 s=ts'=t

s105|s|\le 10^5Σ\Sigma 为小写字母。

做法与参考代码

考虑如何处理这个双射意义下本质不同的限制。

考虑把每个字符串的字符映射到上一个相同字符的出现位置(如果没有上一个字符就是 1-1

我们称这种变换为 hh。比如说 h(identity)={1,1,1,1,1,0,4,1}h(\texttt{identity})=\{-1,-1,-1,-1,-1,0,4,-1\}h(babab)={1,1,0,1,2}h(\texttt{babab})=\{-1,-1,0,1,2\}

可以发现两个字符串是独特的当且仅当 h(s)h(t)h(s)\ne h(t)

这样就成功转化了题意,现在只需要求出 card{h(s[l..r])0lr<s}\operatorname{card} \{h(s[l..r]) \mid 0\le l \le r < |s|\} 即可。

我们观察到了一个性质,h(s[l..r])=h(s[l,n1])[1..(rl+1)]h(s[l..r]) = h(s[l,n-1])[1..(r-l+1)]。也就是说一个子串经过 hh 的变化之后得到的结果,等于对这个后缀进行变换,然后取前缀。

然后根据 SA 的思路,我们希望求出所有经过变换后的后缀的字典序排名和相邻排名经过变换后的后缀的 LCP。

如何求出每个后缀经过变换的结果 h(s[l..n1])h(s[l..n-1])?先求出 h(s)h(s),然后可以发现 h(s[l..n1])h(s[l..n-1]) 就是 h(s)[l..n1]h(s)[l..n-1] 把最多 2626 个位置变为 00 的结果,那么把这些位置拉出来单独在求排名和 LCP 时处理就做完了。

复杂度 O(26nlogn)O(26 n\log n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

#include<bits/stdc++.h>
using namespace std;

#define int long long


struct SA {
int n;
vector<int> sa, rk, h;
SA(const vector<int>& s) {
n = s.size();
sa.resize(n);
h.resize(n - 1);
rk.resize(n);
iota(sa.begin(), sa.end(), 0);
sort(sa.begin(), sa.end(), [&](int a, int b) {return s[a] < s[b];});
rk[sa[0]] = 0;
for (int i = 1; i < n; ++i)
rk[sa[i]] = rk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]);
int k = 1;
vector<int> tmp, cnt(n);
tmp.reserve(n);
while (rk[sa[n - 1]] < n - 1) {
tmp.clear();
for (int i=0;i<k;i++)tmp.push_back(n-k+i);
for (auto i:sa)
if(i>=k) tmp.push_back(i-k);
fill(cnt.begin(), cnt.end(), 0);
for (int i=0;i<n;i++)++cnt[rk[i]];
for (int i=1;i<n;i++)cnt[i]+=cnt[i-1];
for (int i=n-1;i>=0;i--)sa[--cnt[rk[tmp[i]]]]=tmp[i];
swap(rk, tmp);
rk[sa[0]] = 0;
for (int i=1;i<n;i++){
rk[sa[i]]=rk[sa[i-1]];
if(tmp[sa[i-1]]<tmp[sa[i]]||sa[i-1]+k==n||tmp[sa[i-1]+k]<tmp[sa[i]+k])
rk[sa[i]]++;
}
k *= 2;
}
for(int i=0,j=0;i<n;i++){
if(rk[i]==0) j=0;
else{
if(j)j--;
while(i+j<n && sa[rk[i]-1]+j<n && s[i+j]==s[sa[rk[i]-1]+j])++j;
h[rk[i]-1] = j;
}
}
}
};

signed main(){
ios::sync_with_stdio(0);cin.tie(0);

int n; cin>>n;
string s; cin>>s;

vector<int> t(s.size());
vector<int> lst(26, -1);
for(int i=0;i<s.size();i++){
if(lst[s[i]-'a']!=-1)t[i]=i-lst[s[i]-'a'];
lst[s[i]-'a']=i;
}

auto suffix_array = SA(t);
const auto& h = suffix_array.h, rk = suffix_array.rk;

vector<vector<int>> id(n, vector<int>(27));

for(int j=0;j<27;j++)id[n-1][j]=n;
id[n-1][s[n-1]-'a']=n-1;

for(int i=n-2;i>=0;i--){
for(int j=0;j<27;j++)id[i][j]=id[i+1][j];
id[i][s[i]-'a']=i;
}

for(int i=0;i<n;i++)sort(id[i].begin(), id[i].end());

vector<vector<int>> st(n-1,vector<int>(26));

for(int i=0;i<n-1;i++)st[i][0]=h[i];
for(int i=1;i<26;i++){
for(int j=0;j+(1<<i)-1<n-1;j++){
st[j][i]=min(st[j][i-1],st[j+(1<<(i-1))][i-1]);
}
}

auto query = [&](int l,int r)->int {
if(l==r)return n-l;
if(l>r)swap(l,r);
--r;
int o = __lg(r-l+1);
return min(st[l][o], st[r-(1<<o)+1][o]);
};

auto lcp = [&](int u,int v){
for(int i=0,j=0;id[u][i]<n || id[v][j]<n;){
if(id[u][i]-u == id[v][j]-v){
if(id[u][i]==n)return n-u;
if(id[v][j]==n)return n-v;
int now = id[u][i++]-u; ++j;
int nxt = min(id[u][i]-u, id[v][j]-v);
++now; if(now==nxt)continue;
if(query(rk[u+now],rk[v+now])>=nxt-now)continue;
return now+query(rk[u+now],rk[v+now]);
}
if(id[u][i]-u > id[v][j]-v){
swap(u,v);
swap(i,j);
}
if(id[u][i]==n)return n-u;
if(v+id[u][i]-u>=n)return n-v;
int now = id[u][i++]-u;
int nxt = min(id[u][i]-u, id[v][j]-v);
if(t[v+now] != 0)return now;
++now; if(now==nxt)continue;
if(query(rk[u+now],rk[v+now])>=nxt-now)continue;
return now+query(rk[u+now],rk[v+now]);
}
return min(n-u, n-v);
};

auto cmp = [&](int u,int v)-> bool {
int x = lcp(u,v);
if(v+x == n)return 0;
if(u+x == n)return 1;
int val_u = find(id[u].begin(), id[u].end(), u+x) == id[u].end() ? t[u+x] : -1;
int val_v = find(id[v].begin(), id[v].end(), v+x) == id[v].end() ? t[v+x] : -1;
return val_u < val_v;
};

vector<int> p(n);
iota(p.begin(), p.end(), 0);
sort(p.begin(), p.end(), cmp);

long long ans = 1ll * n * (n+1)/2;
for(int i=1;i<n;i++){
ans -= lcp(p[i-1],p[i]);
}
cout << ans << '\n';
return 0;
}

后缀自动机 (SAM)

定义与构建

后缀自动机是能接受字符串 ss 所有后缀的最小 DFA。

我们增量构建一个字符串的 SAM。

注意:SAM 的节点编号并不满足拓扑序,因为我们会分裂节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
struct SAM{ // 注意修改字符集!字符集是小写字母吗?
int last = 1, tot = 1;
int ch[N<<1][26], len[N<<1], f[N<<1];
void ins(char c){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
}sam; // 注意任何跟 SAM 有关的数组都要开两倍

性质

理解 SAM 有两个视角,一个是自动机视角,另一个是后缀连接树(parent 树)视角。

DAG 视角

从这个视角出发,SAM 是一张 DAG,从根节点出发,每次可以通过一个字符进行转移。

字符串的每个子串与 SAM 上从根节点出发的一条路径一一对应。(因为 SAM 可以接受 ss 的所有后缀,那么 SAM 上的路径就对应着一个后缀的前缀,也就是子串)

后缀连接树视角

从这个视角出发,我们可以看到通过后缀连接,SAM 形成了一棵树。

可以有以下性质:

  • 每个节点包含的是 endpos\textrm{endpos} 相同的所有子串。
  • 每个节点包含的是原串的某个前缀的一段长度连续的后缀。
  • 通过一个节点跳到根,构成的节点中所代表的子串恰好取遍了该节点对应的最长子串的所有后缀。

在代码中,每个节点对应的子串长度为 len[f[u]]+1len[u]

应用

求本质不同子串个数

每个子串都对应了 SAM 中的一个节点(可能多个子串对应同一个节点),而每个节点所对应的子串个数就由 len[u]-len[f[u]] 给出,所以计算所有节点的该数值之和就能求出有多少个本质不同的子串。

(这个问题还可以从 DAG 视角求出,见 OI Wiki 相关页面)

例题:「SDOI2016」生成魔咒
给定一个字符串,每次插入一个字符,每次回答当前字符串的不同子串个数。

做法与参考代码

考虑增量统计,每次的增量就是 SAM 中 last 节点对应的 len[u]-len[f[u]]

(注意到虽然插入过程中有些节点会分裂,但分裂的节点并不会对不同子串的个数产生贡献,所以只需要考虑 last 节点的贡献即可)

这个题的字符集特别大,SAM 理论上是不依赖字符集的,但实现一般用数组,本题就不可接受。我们有用 std::map 实现的 SAM,就能通过本题了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include<bits/stdc++.h>
using namespace std;

const int N = 1e5+5;
struct SAM{
int last = 1, tot = 1;
map<int,int> ch[N<<1];
int len[N<<1], f[N<<1];
vector<int> G[N<<1];
int ins(int c){
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p].count(c);p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return len[cur]-len[f[cur]];}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return len[cur]-len[f[cur]];}
int clone=++tot;
ch[clone]=ch[q];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p].count(c)&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
return len[cur]-len[f[cur]];
}
}sam;

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n; cin>>n;
long long ans = 0;
for(int i=0;i<n;i++){
int x; cin>>x;
ans += sam.ins(x);
cout << ans << '\n';
}

return 0;
}

(困难)例题:P6292 区间本质不同子串个数

给定一个长度为 nn 的仅包含小写字母的字符串 SSmm 次询问由 SS 的第 LL 到第 RR 个字符组成的字符串包含多少个本质不同的非空子串。

n105,m2×105n\le 10^5, m\le 2\times 10^5

做法与参考代码

回想求解区间不同元素个数的做法。
一般有两种做法:

  1. rr 扫描线,用一个数据结构,对于每种元素在最后一次出现的位置有 11 的贡献,查询就是查 [l,r][l,r] 的区间和。
  2. rr 扫描线,用一个数据结构,ll 位置保存 [l,r][l,r] 的答案,查询就是单点查,修改就考虑如何去维护贡献。

一般来说,第一种做法会比较干净一点,我们考虑把这个搬到 SAM 上。

现在还是对 rr 进行扫描线,关心的就是每种本质不同子串的左端点最后一次出现位置。

多增加一个字符所可能带来的新的子串就是 s[1..r]s[1..r] 的所有后缀,也就是后缀树上这个节点到根的所有部分,那么我们应该把这一条链上的所有原来的贡献抵消,然后加入新的贡献。

新的贡献就是一个区间加,对应每个后缀。但对原来贡献的抵消,如果原来是多个不同的 endpos 抵消就比较困难,为了解决这个问题,考虑树剖+势能线段树。线段树每次只当区间内所有颜色都一样的时候才进行操作。

根据势能进行分析,每次线段树递归都会让颜色段数量减少一个,而每次会贡献 O(logn)O(\log n) 的颜色段,所以复杂度正确。

每次树剖跳链的复杂度只与颜色段个数有关,所以复杂度是 O(nlog2n+qlogn)O(n\log^2 n + q\log n) 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include<bits/stdc++.h>
using namespace std;

#define int long long

const int N = 2e5+5;

int last = 1, tot = 1;
signed ch[N][26], len[N], f[N];
void ins(char c){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}

int hson[N], sz[N], dep[N], fa[N], top[N];
int dfn[N], dd, node[N];
vector<int> G[N];

void dfs1(int u,int ff){
sz[u] = 1;
dep[u] = dep[ff] + 1;
fa[u] = ff;
for(int v:G[u]){
if(v==ff)continue;
dfs1(v,u);
sz[u] += sz[v];
if(sz[v] > sz[hson[u]])hson[u] = v;
}
}
void dfs2(int u,int t){
dfn[u] = ++dd;
top[u] = t;
node[dd] = u;
if(hson[u])dfs2(hson[u], t);
for(int v:G[u]){
if(v==fa[u] || v==hson[u])continue;
dfs2(v,v);
}
}


// 区间加,区间查询
struct SegAns{
int t[N*4], tag[N*4];

void upd(int o,int l,int r,int k){
t[o] += (r-l+1) * k;
tag[o] += k;
}
void pushdown(int o,int l,int r){
if(tag[o]){
int mid = (l+r)>>1;
int lch=o<<1,rch=o<<1|1;
upd(lch,l,mid,tag[o]);
upd(rch,mid+1,r,tag[o]);
tag[o] = 0;
}
}
void add(int o,int l,int r,int ql,int qr, int qk){
if(ql<=l && r<=qr){
upd(o, l,r,qk);
return;
}
int mid=(l+r)>>1;
int lch=o<<1,rch=o<<1|1;
pushdown(o,l ,r);
if(ql<=mid)add(lch,l,mid,ql,qr,qk);
if(qr>mid)add(rch,mid+1,r,ql,qr,qk);
t[o] = t[lch] + t[rch];
}
int query(int o,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr)return t[o];
int mid=(l+r)>>1;
int lch=o<<1,rch=o<<1|1;
pushdown(o,l , r);
if(qr<=mid)return query(lch,l,mid,ql,qr);
if(ql>mid)return query(rch,mid+1,r,ql,qr);
return query(lch,l,mid,ql,qr)+query(rch,mid+1,r,ql,qr);
}
}t1;

// 势能线段树
struct SegColor{
int t[N*4], tag[N*4];

void upd(int o,int k){
t[o] = tag[o] = k;
}
void pushdown(int o){
if(tag[o]){
int lch=o<<1,rch=o<<1|1;
upd(lch,tag[o]),upd(rch,tag[o]);
tag[o] = 0;
}
}

void modify(int o,int l,int r,int ql,int qr,int qk){
if(ql<=l && r<=qr && t[o]!=-1){
if(t[o]){
// 抵消这一段的贡献
// endpos 是 t[o]
t1.add(1,1,tot,t[o]-len[node[r]]+1, t[o]-len[fa[node[l]]], -1);
}
upd(o, qk);
return;
}
pushdown(o);
int mid = (l+r)>>1;
int lch = o<<1, rch=o<<1|1;
if(ql<=mid)modify(lch,l,mid,ql,qr,qk);
if(qr>mid)modify(rch,mid+1,r,ql,qr,qk);

t[o] = (t[lch] == t[rch]) ? t[lch] : -1;
}
}t2;

void modify(int u,int k){
while(u > 1){
t2.modify(1,2,tot,max(2ll,dfn[top[u]]), dfn[u], k);
u = fa[top[u]];
}
}

int ans[N];
vector<pair<int,int>> queries[N];

signed main(){
ios::sync_with_stdio(0); cin.tie(0);

string s; cin>>s;
int n = s.size();

for(auto x:s)ins(x);

for(int i=2;i<=tot;i++){
G[f[i]].push_back(i);
G[i].push_back(f[i]);
}

dfs1(1,0);
dfs2(1,1);
int q; cin>>q;
for(int i=1;i<=q;i++){
int l,r; cin>>l>>r;
queries[r].push_back({l, i});
}
int now = 1;
for(int i=1;i<=n;i++){
now = ch[now][s[i-1]-'a'];
modify(now, i);
t1.add(1,1,tot,1,i,1);
for(auto [l,id]:queries[i]){
ans[id] = t1.query(1,1,tot,l,i);
}
}
for(int i=1;i<=q;i++)cout<<ans[i]<<'\n';
return 0;
}

求每个子串的出现次数

我们知道,子串是某个前缀的后缀。

所以从后缀连接树视角,我们可以获取每个前缀在后缀连接树上对应的节点,然后将这个节点到根的路径权值加 1,就求出了这个子串的出现位置。

(同样,注意到这个节点可能会在之后分裂,不过没有关系,因为代码保证了新分裂出的节点是长度比较短的那个,所以是对的,贡献所在的位置是正确的。)

例题:P3804 【模板】后缀自动机(SAM)
给定一个只包含小写字母的字符串 SS

请你求出 SS 的所有出现次数不为 11 的子串的出现次数乘上该子串长度的最大值。

S106|S|\le 10^6

做法与参考代码

用 SAM 的增量构造每次插入一个节点,然后在后缀连接树上把这个节点到根的路径全部加 1,然后遍历每个节点,用 len 乘上出现次数就能统计答案了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;

using ll = long long;

const int N = 1e6;
int last = 1, tot = 1;
int ch[N<<1][26], len[N<<1], f[N<<1];
int cnt[N<<1];
vector<int> G[N<<1];
void ins(char c){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
cnt[cur]=1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}

ll ans = 0;
void dfs(int u){
for(auto v:G[u]){
dfs(v);
cnt[u]+=cnt[v];
}

if(cnt[u]>1){
ans = max(ans, 1ll*cnt[u]*len[u]);
}
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);
string s; cin>>s;
for(int i=0;i<s.size();i++)ins(s[i]);
for(int i=2;i<=tot;i++){
G[f[i]].push_back(i);
}
dfs(1);
cout << ans << '\n';
return 0;
}

在 SAM 上做单模匹配

考虑如下问题,给定一个文本串 ss 和模式串 tt,问 ttss 中出现了几次?

ss 建立 SAM,然后用自动机视角,初始节点为根节点,然后依次加入 tt 的每个字符,尝试通过自动机进行转移,如果没有对应的出边就跳后缀连接树的父亲,如果到根了还转移不出去就把当前的匹配节点改为根。

复杂度 O(s+t)O(|s|+|t|)

例题:P5357 【模板】AC 自动机

给你一个文本串 SSnn 个模式串 T1nT_{1∼n},请你分别求出每个模式串 TiT_iSS 中出现的次数。

做法与参考代码

SS 建出 SAM,同时用后缀连接树求出每个节点的子串个数,然后对于每个 TiT_i 去在 SAM 匹配上即可。

注意到我们可以在线回答询问,这是 AC 自动机做不到的。

参考代码(虽然因为空间原因无法通过本题):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include<bits/stdc++.h>
using namespace std;

const int N = 2e6+5;
struct SAM{ // 注意修改字符集!字符集是小写字母吗?
int last = 1, tot = 1;
int ch[N<<1][26], len[N<<1], f[N<<1];
int cnt[N<<1];
void ins(char c){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
vector<int> G[N];
void dfs(int u){
for(int v:G[u]){
dfs(v);
cnt[u] += cnt[v];
}
}
void get_cnt(){
for(int i=2;i<=tot;i++)G[f[i]].push_back(i);
dfs(1);
}
}sam; // 注意任何跟 SAM 有关的数组都要开两倍

int main(){
ios::sync_with_stdio(0);cin.tie(0);
int n; cin>>n;
vector<string> t(n);
for(int i=0;i<n;i++)cin>>t[i];

string s;cin>>s;
for(int i=0;i<s.size();i++){
sam.ins(s[i]);
++sam.cnt[sam.last];
}
sam.get_cnt();

for(int i=0;i<n;i++){
[&](){
int now = 1;
for(auto ch:t[i]){
if(!sam.ch[now][ch-'a']){
cout << 0 << '\n';
return;
}
now = sam.ch[now][ch-'a'];
}
cout << sam.cnt[now] << '\n';
}();
}
return 0;
}

不难发现,这个匹配的性质十分优良,我们介绍两个应用:

  1. 注意到我们可以求出 tt 的每个前缀与 ss 中的子串最长匹配长度,于是可以求出 LCS。

例题: SP1812 LCS2 - Longest Common Substring II

求字符串 s1,s2,,sns_1, s_2, \ldots, s_n 的最长公共子串。

做法与参考代码(匹配角度)

不妨设 s1s_1s1,s2,,sns_1,s_2,\ldots,s_n 中长度最短的字符串。

先对 s1s_1 建立 SAM,然后考虑用 s2,s3,,sns_2,s_3,\ldots,s_n 依次进行匹配。可以发现,在一轮匹配中,我们可以得到 SAM 上一些节点所对应的最长匹配长度,然后再沿着后缀连接树把这个长度转移上去,就得到每个节点与 sks_k 的最长匹配长度。这样把每一轮的长度取一个最小值,就知道了 s1s_1 的每一个子串与 s2,s3,,sns_2,s_3,\ldots,s_n 匹配的结果。

于是就能求出 LCS,复杂度为 O(si)O(\sum |s_i|)。(这是因为 s1s_1 是最短的字符串,每次匹配 sks_k 的复杂度是 O(s1+sk)O(|s_1|+|s_k|)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
using namespace std;

const int N = 1e6+5;
int cnt = 0;
struct SAM{
int last = 1, tot = 1;
int ch[N<<1][26], len[N<<1], f[N<<1];
vector<int> G[N<<1];

int ans[N<<1],res[N<<1];
void ins(int c){
c-='a';
int p=last,cur=last=++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1; return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
void get_fa(){
for(int i=2;i<=tot;i++)G[f[i]].push_back(i);
}

void dfs(int u){
for(int v:G[u]){
dfs(v);
res[u] = max(res[u], min(res[v], len[u]));
}
}
void update_ans(){
dfs(1);
for(int i=1;i<=tot;i++){
ans[i] = min(ans[i], res[i]);
res[i] = 0;
}
}
}sam;

int main(){
ios::sync_with_stdio(0);cin.tie(0);
vector<string> s;
string tmp;
while(cin>>tmp){
s.push_back(tmp);
}

auto it = min_element(s.begin(), s.end(), [](const auto& lhs,const auto& rhs){return lhs.size()<rhs.size();});
swap(*it, *s.begin());

for(auto ch:s[0]){
sam.ins(ch);
}
sam.get_fa();

fill(sam.ans, sam.ans+N*2, (int)1e9);

for(int i=1;i<s.size();i++){
int u=1,l=0;
for(auto c:s[i]){
c -= 'a';
while(u && !sam.ch[u][c])u=sam.f[u],l=sam.len[u];
if(!u) u = 1, l = 0;
else{
u = sam.ch[u][c];
++l;
sam.res[u] = max(sam.res[u], l);
}
}
sam.update_ans();
}
cout << *max_element(sam.ans+1, sam.ans+sam.tot+1) << '\n';
return 0;
}
做法与参考代码(伪广义SAM)

把所有字符串都拼起来,中间用一个分隔符连接。

那么,一个子串在第 ii 个字符串内,对应它是第 ii 个字符串所对应的所有前缀节点,也就是后缀连接树上的若干条链。

那么考虑压位,第 ii 位表示这个节点是否是第 ii 个字符串对应的子串,然后找到是所有字符串子串的节点,输出这些节点的最长长度即可。

复杂度 O(nwsi)O(\frac n w \sum |s_i|)。(O(nw)O(\frac n w) 是压位的复杂度)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<bits/stdc++.h>
using namespace std;

const int N = 1e6+5;
int cnt = 0;
struct SAM{
int last = 1, tot = 1;
int ch[N<<1][27], len[N<<1], f[N<<1];
vector<int> G[N<<1];

int have[N<<1];
void ins(int c){
c-='a';
int p=last,cur=last=++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1; return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<27;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
void get_fa(){
for(int i=2;i<=tot;i++)G[f[i]].push_back(i);
}

void dfs(int u){
for(int v:G[u]){
dfs(v);
have[u] |= have[v];
}
}

}sam;

int main(){
ios::sync_with_stdio(0);cin.tie(0);
vector<string> s;
string tmp;
int n = 0;
while(cin>>tmp){
s.push_back(tmp);
++n;
}

for(int i=0;i<n;i++){
for(auto ch:s[i]){
sam.ins(ch);
sam.have[sam.last] |= (1<<i);
}
sam.ins('{');
}

sam.get_fa();
sam.dfs(1);

int ans = 0;
for(int i=1;i<=sam.tot;i++){
if(sam.have[i] == (1<<n)-1){
ans = max(ans, sam.len[i]);
}
}
cout << ans << '\n';
return 0;
}
  1. 同理,我们可以做到对于模式串开头弹出字符,结尾插入字符的增量匹配,插入字符已经讨论过了,弹出字符就是减少长度,如果长度太长就往后缀连接树的父亲跳。

例题:CF235C. Cyclical Quest
给定一个字符串 SSnn 询问一个字符串 xix_i,问 SS 中有多少子串和 xix_i 是循环同构的。

S106,n105,xi106|S|\le 10^6, n\le 10^5, \sum |x_i|\le 10^6

做法与参考代码

因为可以在开头弹出字符,我们只需要求出 xx 的每个循环同构串出现了多少次加起来即可。注意不要重复判断(可以考虑打 tag 或用 kmp/哈希 求出循环节)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<bits/stdc++.h>
using namespace std;

const int N = 1e6+5;

int last = 1, tot = 1;
int ch[N<<1][26], len[N<<1], f[N<<1];
int cnt[N<<1];
void ins(char c,int x){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur]=len[p]+1;
cnt[cur] += x;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<26;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
vector<int> G[N<<1];
void dfs(int u){
for(int v:G[u]){
dfs(v);
cnt[u] += cnt[v];
}
}
void get_parent(){
for(int i=2;i<=tot;i++)G[f[i]].push_back(i);
dfs(1);
}

int tag[N<<1];


int main(){
ios::sync_with_stdio(0);cin.tie(0);

string s;cin>>s;
int n;cin>>n;
for(int i=0;i<s.size();i++)ins(s[i],1);

get_parent();

for(int i=1;i<=n;i++){
string t; cin>>t;
ll ans = 0;
int now = 1, l = 0;

auto append = [&](char c){
while(now && !ch[now][c-'a'])now=f[now],l=len[now];
if(!now){
now = 1;
l = 0;
}
else{
now=ch[now][c-'a'];
l++;
}
};
auto pop = [&](){
if(l < t.size())return;
if(len[f[now]]==l-1){
now = f[now];
}
--l;
};
auto calc = [&](){
if(l == t.size()){
if(tag[now] == i)return;
tag[now] = i;
ans += cnt[now];
}
};

for(int j=0;j<t.size();j++)
append(t[j]);
calc();

for(int j=1;j<t.size();j++){
pop();
append(t[j-1]);
calc();
}

cout << ans << '\n';
}
return 0;
}

出现位置查询

给定一个文本串 ss,查询模式串 ppss 中(第一个)出现位置。

对于 ss 建出 SAM,然后找到 pp 对应的节点。那么所有出现位置对应的是后缀连接树上这个节点子树内所有终止节点。于是如果要求第一次出现位置递推即可,求出所有出现位置的话可以暴力搜索,复杂度是 O(ans)O(ans) 的。

练习

1. CF547E Mike and Friends

给定 nn 个字符串 s1,s2,,sns_1,s_2,\ldots,s_nqq 次查询 (l,r,k)(l,r,k),回答 sl,sl+1,,srs_l, s_{l+1}, \ldots, s_r 中,sks_k 出现了几次。

n2×105,q5×105,si2×105n\le 2\times 10^5, q\le 5\times 10^5, \sum|s_i|\le 2\times 10^5

做法与参考代码

首先这个询问可以差分,下面考虑怎么回答 (1,r,k)(1,r,k)

对于 s1,s2,,sns_1,s_2,\ldots,s_n 建立伪广义自动机,中间用分隔符隔开。

考虑对 rr 扫描线,每次会加入若干个节点,需要做若干次链加,然后查询可以预处理出 sks_k 在 SAM 上所对应的节点。

所以需要做的是链加单点查询,可以用树状数组单 log 解决。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;

using ll=long long;

struct SAM{
int last=1,tot=1;
int ch[N<<1][27],len[N<<1],f[N<<1];
void ins(char c){
c-='a';
int p=last,cur=last=++tot;
len[cur]=len[p]+1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone=++tot;
for(int i=0;i<27;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q], len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}

vector<int> G[N<<1];
void build_tree(){
for(int i=2;i<=tot;i++)G[f[i]].push_back(i);
}
int dfn[N<<1],sz[N<<1],d;
void dfs(int u){
dfn[u]=++d;sz[u]=1;
for(int v:G[u])dfs(v),sz[u]+=sz[v];
}

ll t[N<<1];
int lowbit(int x){return x&-x;}
void add(int i){
for(;i<=tot;i+=lowbit(i)){
t[i]++;
}
}
ll query(int i){
ll x = 0;
for(;i;i-=lowbit(i))x+=t[i];
return x;
}

void add_path_to_root(int u){add(dfn[u]);}
ll query_node(int u){return query(dfn[u]+sz[u]-1)-query(dfn[u]-1);}
}sam;


string s[N];
vector<int> ver[N];
ll ans[N];

struct Query{
int id;
int str,k;
};
vector<Query> qs[N];
int real_pos[N];

int main(){
ios::sync_with_stdio(0);cin.tie(0);

int n,q; cin>>n>>q;
for(int i=1;i<=n;i++){
cin>>s[i];
for(auto ch:s[i]){
sam.ins(ch);
ver[i].push_back(sam.last);
}
sam.ins('z'+1);
}

sam.build_tree();
sam.dfs(1);

for(int i=1;i<=n;i++){
int u = 1;
for(auto ch:s[i]){
u = sam.ch[u][ch-'a'];
}
real_pos[i] = u;
}

for(int i=1;i<=q;i++){
int l,r,k;cin>>l>>r>>k;
qs[r].push_back({i,k,1});
qs[l-1].push_back({i,k,-1});
}

for(int i=1;i<=n;i++){
for(auto u:ver[i]){
sam.add_path_to_root(u);
}
for(auto [id,str,k]:qs[i]){
ans[id] += k * sam.query_node(real_pos[str]);
}
}

for(int i=1;i<=q;i++)cout<<ans[i]<<'\n';
return 0;
}

2. CF666E. Forensic Examination

给定字符串 ssmm 个字符串 t1,t2,,tmt_1,t_2,\ldots,t_mqq 次询问 (l,r,pl,pr)(l,r,p_l,p_r),回答在 tl,tl+1,,trt_l,t_{l+1},\ldots,t_r 中,哪个字符串中 s[pl,pr]s[p_l, p_r] 出现次数最多,回答下标。若有多个,回答最小的下标。

s,q5×105,ti5×104|s|,q\le 5\times 10^5, \sum |t_i| \le 5\times 10^4

做法与参考代码

t1,t2,,tmt_1,t_2,\ldots,t_m 建立伪广义自动机,中间用分隔符隔开,然后我们希望统计出 SAM 上每个节点所对应的字符串中,在哪个字符串出现最多,于是我们考虑对每一个节点开一个动态开点线段树,维护表示第 tit_i 中出现了几次这个串。

然后在后缀连接树上合并的过程就是线段树合并,直接合并即可。

然后对于每个查询,要找到 s[pl,pr]s[p_l,p_r] 对应的节点。这个可以预处理出 s[1..pr]s[1..p_r] 对应的节点(通过上面写的匹配过程),然后考虑倍增跳 parent 树,跳到那个长度对的节点,记得特判 s[pl,pr]s[p_l, p_r] 根本就没有出现过的情况。

需要把所有询问离线下来挂在点上,然后 dfs 的时候一边合并一边计算答案。

复杂度 O(s+(ti+q)logti)O(|s|+ (\sum|t_i|+ q )\log \sum |t_i|)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include<bits/stdc++.h>
using namespace std;


const int N = 5e5+5;

struct SAM{
int last =1 , tot = 1;
int ch[N<<1][27], len[N<<1], f[N<<1];
void ins(char c){
c -= 'a';
int p = last, cur = last = ++tot;
len[cur] = len[p] + 1;
for(;p&&!ch[p][c];p=f[p])ch[p][c]=cur;
if(!p){f[cur]=1;return;}
int q=ch[p][c];
if(len[q]==len[p]+1){f[cur]=q;return;}
int clone = ++tot;
for(int i=0;i<27;i++)ch[clone][i]=ch[q][i];
f[clone]=f[q],len[clone]=len[p]+1;
f[q]=f[cur]=clone;
for(;p&&ch[p][c]==q;p=f[p])ch[p][c]=clone;
}
}sam;

const int X = 1e7+5;
struct Seg{
int rt[N<<1];
struct Node{
int l,r,v,k;
}t[X];
int tot;



void add(int& o,int l,int r,int qi,int qk){
if(!o)o=++tot;
if(l==r){
t[o].v += qk;
t[o].k = l;
return;
}
int mid = (l+r)>>1;
if(qi<=mid)add(t[o].l,l,mid,qi,qk);
else add(t[o].r,mid+1,r,qi,qk);
t[o].v = max(t[t[o].l].v, t[t[o].r].v);
t[o].k = t[o].v==t[t[o].l].v ? t[t[o].l].k : t[t[o].r].k;
}

int merge(int u,int v,int l,int r){
if(!u || !v)return u+v;
if(l==r){
t[u].v += t[v].v;
return u;
}
int mid = (l+r)>>1;
t[u].l = merge(t[u].l, t[v].l, l, mid);
t[u].r = merge(t[u].r, t[v].r, mid+1, r);
t[u].v = max(t[t[u].l].v, t[t[u].r].v);
t[u].k = t[u].v==t[t[u].l].v ? t[t[u].l].k : t[t[u].r].k;
return u;
};

pair<int,int> merge(auto u,auto v){
if(u.first >= v.first)return u;
return v;
}

pair<int,int> query(int o,int l,int r,int ql,int qr){
if(ql<=l && r<=qr){
return {t[o].v, t[o].k};
}
int mid = (l+r)>>1;
if(qr<=mid)return query(t[o].l,l,mid,ql,qr);
if(ql>mid)return query(t[o].r,mid+1,r,ql,qr);
return merge(query(t[o].l,l,mid,ql,qr), query(t[o].r,mid+1,r,ql,qr));
}
}seg;


int n;

pair<int,int> ans[N<<1];
vector<int> G[N<<1];

vector<tuple<int,int,int>> queries[N];
int fa[N][20];
void dfs1(int u){
fa[u][0] = sam.f[u];
for(int i=1;i<20;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int v:G[u]){
dfs1(v);
}
}
void dfs2(int u){
for(int v:G[u]){
dfs2(v);
seg.rt[u] = seg.merge(seg.rt[u], seg.rt[v],1, n);
}
for(auto [id,l,r]:queries[u]){
ans[id] = seg.query(seg.rt[u],1,n,l,r);
if(!ans[id].first)ans[id].second=l;
}
}

int main(){
ios::sync_with_stdio(0); cin.tie(0);

string s; cin>>s;
s = ' ' + s;
cin>>n;
for(int i=1;i<=n;i++){
string t;cin>>t;
for(auto x:t){
sam.ins(x);
seg.add(seg.rt[sam.last], 1, n, i, 1);
}
sam.ins('{');
}

for(int i=1;i<=sam.tot;i++){
G[sam.f[i]].push_back(i);
}

dfs1(1);

vector<int> match_len(s.size()+1), match_node(s.size()+1);
int now = 1,len = 0;
for(int i=1;i<s.size();i++){
while(now && !sam.ch[now][s[i]-'a'])now=sam.f[now],len=sam.len[now];
if(!now)now=1,len=0;
else{
now = sam.ch[now][s[i]-'a'];
++len;
}

match_len[i] = len;
match_node[i] = now;
}

int q; cin>>q;
for(int k=1;k<=q;k++){
int l,r,pl,pr; cin>>l>>r>>pl>>pr;
if(match_len[pr] < pr-pl+1){
ans[k] = {0, l};
continue;
}

int u = match_node[pr];
int need = pr - pl + 1;
if(!(sam.len[sam.f[u]] < need && need <= sam.len[u])){
for(int i=19;i>=0;i--){
if(sam.len[sam.f[fa[u][i]]] >= need){
u = fa[u][i];
}
}
u = fa[u][0];
}
queries[u].push_back({k,l,r});
}
dfs2(1);
for(int i=1;i<=q;i++)cout<<ans[i].second<<' '<<ans[i].first<<'\n';
return 0;
}

广义后缀自动机

不会。

回文自动机 (PAM)

定义与构建

回文自动机是一种能高效存储字符串所有子串的结构。

具体来说,回文自动机由两颗 Trie 构成,两颗的根分别称为奇根和偶根。

回文自动机上每个除根节点的节点都代表一个回文子串,具体来说是从这个节点出发到根再到这个节点所构成的字符串(如果在奇根所构成的树上,到根的最后一个字符只读一遍)。(或者说,Trie 树上存储的是这个回文子串的后一半)

然后同样也有 fail 指针,定义为这个节点所代表的回文子串的最长回文后缀对应的节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
struct PAM {
// tot: 字符串长度,cnt: 节点个数
int fail[N], ch[N][26], len[N], s[N], tot, cnt, lst;
// fail : 当前节点的最长回文后缀。
// ch : 在当前节点的前后添加字符,得到的回文串。
PAM() {
len[0] = 0, len[1] = -1, fail[0] = 1;
tot = lst = 0, cnt = 1, s[0] = -1;
}
int get_fail(int x) {
while (s[tot-1-len[x]] != s[tot]) x = fail[x];
return x;
}
void insert(char c) {
s[++tot] = c - 'a';
int p = get_fail(lst);
if(!ch[p][s[tot]]) {
len[++cnt] = len[p] + 2;
int t = get_fail(fail[p]);
fail[cnt] = ch[t][s[tot]];
ch[p][s[tot]] = cnt;
}
lst=ch[p][s[tot]];
}
} pam;

性质

本质不同回文子串是 O(n) 的

通过我们的构建方式易得。

例题:P12671 「TFXOI Round 2」String

给定字符串 ssqq 次询问 l1,l2l_1,l_2,找到以下条件的 S1S_1, S2S_2S1S_1 最早出现位置,或报告不存在。

  • S1,S2S_1,S_2ss 的回文子串。
  • S1=l1,S2=l2|S_1| = l_1, |S_2| = l_2
  • S1S_1S2S_2 的前缀且第一个字符在 SS 中出现位置相同。

s,q5×105|s|,q\le 5\times 10^5保证 l1,l2l_1,l_2 随机生成

做法与参考代码

因为我们知道本质不同回文子串级别是 O(n)O(n) 的,所以一个复杂度为 O(cnt(l2))O(\sum \textrm{cnt}(l_2)) 的做法是可以接受的,其中 cnt(l2)\textrm{cnt}(l_2) 表示 ss 中长度为 l2l_2 的本质不同回文子串个数。这个是期望 O(n)O(n) 的。

然后我们枚举每个长度为 l2l_2 的回文串,就是想要知道 fail 树的父亲上有没有长度为 l2l_2 的节点,这个是倍增就做完了。

顺便在 PAM 上记录一下每个回文子串的第一次出现位置。

期望复杂度 O(nlogn)O(n\log n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<bits/stdc++.h>
using namespace std;

const int N = 5e5+5;

struct PAM {
// tot: 字符串长度,cnt: 节点个数
int fail[N], ch[N][26], len[N], s[N], tot, cnt, lst;
int fa[N][20];

int first_pos[N];

// fail : 当前节点的最长回文后缀。
// ch : 在当前节点的前后添加字符,得到的回文串。
PAM() {
len[0] = 0, len[1] = -1, fail[0] = 1;
tot = lst = 0, cnt = 1, s[0] = -1;
}
int get_fail(int x) {
while (s[tot-1-len[x]] != s[tot]) x = fail[x];
return x;
}
void insert(char c) {
s[++tot] = c - 'a';
int p = get_fail(lst);
if(!ch[p][s[tot]]) {
len[++cnt] = len[p] + 2;
first_pos[cnt] = tot - len[cnt] + 1;
int t = get_fail(fail[p]);
fail[cnt] = ch[t][s[tot]];
ch[p][s[tot]] = cnt;
}
lst=ch[p][s[tot]];
}

vector<int> G[N];
vector<int> nodes[N];
void dfs(int u){
nodes[len[u]].push_back(u);
fa[u][0] = fail[u];
for(int i=1;i<20;i++)fa[u][i]=fa[fa[u][i-1]][i-1];

for(int v:G[u]){
dfs(v);
}
}
void pre_work(){
for(int i=2;i<=cnt;i++)G[fail[i]].push_back(i);
dfs(1);dfs(0);
}

int query(int l1,int l2){
int ans = 1e9;
for(auto u:nodes[l2]){
int v = u;
for(int i=19;i>=0;i--){
if(len[fa[u][i]] > l1)u=fa[u][i];
}
u = fa[u][0];
if(len[u] == l1)ans = min(ans, first_pos[v]);
}
return (ans == 1e9) ? -1 : ans;
}
} pam;



int main(){
ios::sync_with_stdio(0); cin.tie(0);
int n,q; cin>>n>>q;
string s; cin>>s;
for(auto x:s)pam.insert(x);

pam.pre_work();

while(q--){
int l,r; cin>>l>>r;
cout << pam.query(l,r) << '\n';
}
return 0;
}

Fail 树到根的路径构成 O(log n) 段等差数列

考察 fail 树上到根的路径,可以证明长度构成了 O(logn)O(\log n) 段等差数列,这是非常优良的性质。

我们可以解决一类 dp 问题,转移方式为每次转移一个回文子串,也就是

dp[i]=f(s[j+1..i] 是回文串g(dp[j]))dp[i] = f\left(\bigodot_{s[j+1..i]\text{ 是回文串}}{g(dp[j])}\right)

其中 \odot 是一种运算(需要满足交换律和结合律),f,gf,g 是某个映射,下面介绍如何解决这个问题。

我们维护 diff[x]\textit{diff}[x] 表示节点 xx 到父亲的长度差,slink[x]\textit{slink}[x] 表示节点 xx 的祖先中第一个 diff\textit{diff} 不同的。

在维护 dp 的时候,设 g[u]g[u] 表示 uu 这个节点所对应的回文子串所在的等差链上需要转移的 dp 值之和。也就是 g[u]=slink[x]=slink[u]g(dp[ilen[x]])g[u] = \bigodot_{\textit{slink}[x]=\textit{slink}[u]} g(dp[i-len[x]]),其中 ii 是最大的下标,满足 uuii 所对应的节点经过若干次 slink\textit{slink} 映射得到(也就是计算 ii 的 dp 值要用到的 uu)。

那么如果我们已经维护好了 gg,根据定义,dp[i]dp[i] 就是 O(logn)O(\log n)g[u]g[u] 的信息合并起来,在应用 ff

我们考虑如何维护这个 g[u]g[u]。先考虑 g[fail[u]]g[fail[u]] 的信息和 g[u]g[u] 的信息有什么差别。fail[u]fail[u] 上一次出现的位置一定是 idiff[u]i-\textit{diff}[u](因为 fail[u]fail[u]uu 的最长回文后缀,然后对比可以发现,g[u]g[u]g[fail[u]]g[fail[u]] 多了一项,就是那个最短的长度所对应的链,根据定义其长度为 len[slink[u]]+dif[u]len[slink[u]]+dif[u],这个就是 dp[ilen[slink[u]]diff[u]]dp[i-len[\textit{slink}[u]]-\textit{diff}[u]]。这样我们就成功维护了 g[u]g[u]

然后对于 slink[u],slink[u],\ldots,维护是同理的。

1
2
3
4
5
6
7
8
9
dp[0] = 0;
for(int i=1;i<s.size();i++){
ins(i);
for(int u=now;u>1;u=slink[u]){
g[u] = dp[i-len[slink[u]]-dif[u]];
if(dif[u]==dif[fail[u]])g[u]=min(g[u],g[fail[u]]);
dp[i] = min(dp[i], g[u]+1);
}
}

应用

区间本质不同回文子串个数

考虑离线按 rr 扫描,数据结构 ll 位置保存的恰好就是 [l..r][l..r] 的答案。

考虑新加入一个字符能带来什么贡献。也就是需要把 [旧的回文串开始位置+1,新的回文串开始位置][\text{旧的回文串开始位置}+1,\text{新的回文串开始位置}] 这一个区间加一。

可以发现,在一个等差数列上的回文子串(不算最长的那个)每个都恰好往后出现了 diff[x]diff[x] 的长度,因此这些区间恰好拼起来了!如下图所示:

然后最长的那个的贡献是可以快速计算的,维护每个回文子串的上一次出现位置即可。所以总共的贡献区间就是 [当前等差数列上最长的回文串的上一次出现位置+1,ilen[slink[u]]diff[u]+1][\text{当前等差数列上最长的回文串的上一次出现位置}+1,i-len[slink[u]]-diff[u]+1]

所以我们只需要做 O(logn)O(\log n) 次区间加即可。复杂度 O(nlog2n)O(n\log^2 n)

参考代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include<bits/stdc++.h>
using namespace std;

const int N = 1e5+5;


struct BIT{
int t[N];
void clear(){
memset(t, 0, sizeof(t));
}
int lowbit(int x){return x&-x;}
void add(int i,int x){
if(!i)return;
// cerr << "add " << i << " = " << x << endl;
for(;i<N;i+=lowbit(i))t[i]+=x;
}
int query(int i){
int r = 0;
for(;i;i-=lowbit(i))r+=t[i];
return r;
}
}bit;

// 单点修改,区间 max
struct Seg{
int t[N*4];
void build(int o,int l,int r){
t[o] = 0;
if(l==r)return;
int lch=o<<1,rch=o<<1|1;
int mid = (l+r)>>1;
build(lch,l,mid),build(rch,mid+1,r);
}
int query(int o,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr)return t[o];
int lch=o<<1,rch=o<<1|1;
int mid = (l+r)>>1;
if(qr<=mid)return query(lch,l,mid,ql,qr);
if(ql>mid)return query(rch,mid+1,r,ql,qr);
return max(query(lch,l,mid,ql,qr),query(rch,mid+1,r,ql,qr));
}

void modify(int o,int l,int r,int qi,int qk){
if(l==r){
t[o] = qk;
return;
}
int lch=o<<1,rch=o<<1|1;
int mid = (l+r)>>1;
if(qi<=mid)modify(lch,l,mid,qi,qk);
else modify(rch,mid+1,r,qi,qk);
t[o] = max(t[lch], t[rch]);
}
}seg;

struct PAM {
vector<int> G[N];
int fail[N], ch[N][26], len[N], s[N], tot, cnt, lst, dd;
int diff[N], slink[N];
int nodes[N];
int dfn[N], sz[N];
void clear() {
for(int i=0;i<=cnt;i++){
fail[i]=len[i]=diff[i]=slink[i]=dfn[i]=sz[i]=0;
for(int j=0;j<26;j++)ch[i][j]=0;
G[i].clear();
}
for(int i=0;i<=tot;i++)s[i]=0;
dd = 0;
len[0] = 0, len[1] = -1, fail[0] = 1;
tot = lst = 0, cnt = 1, s[0] = -1;
}
int get_fail(int x) {
while (s[tot-1-len[x]] != s[tot]) x = fail[x];
return x;
}
void insert(char c) {
s[++tot] = c - 'a';
int p = get_fail(lst);
if(!ch[p][s[tot]]) {
len[++cnt] = len[p] + 2;
int t = get_fail(fail[p]);
fail[cnt] = ch[t][s[tot]];
diff[cnt] = len[cnt] - len[fail[cnt]];
if(diff[cnt] == diff[fail[cnt]])slink[cnt]=slink[fail[cnt]];
else slink[cnt] = fail[cnt];
ch[p][s[tot]] = cnt;
}
lst=ch[p][s[tot]];
nodes[tot] = lst;
}

void dfs(int u){
dfn[u] = ++dd;
sz[u] = 1;
for(int v:G[u]){
dfs(v);
sz[u] += sz[v];
}
}
void pre_work(){
for(int i=2;i<=cnt;i++)G[fail[i]].push_back(i);
dfs(0);
dfs(1);
}


void update_pos(int u,int i){
seg.modify(1,1,cnt,dfn[u],i);
}
int last_pos(int u){
return seg.query(1,1,cnt,dfn[u],dfn[u]+sz[u]-1);
}
}pam;



vector<pair<int,int>> qs[N];
int ans[N];

void solve(){
string s; cin>>s;
int q; cin>>q;

for(int i=1;i<=s.size();i++)qs[i].clear();

pam.clear();
bit.clear();

for(auto x:s)pam.insert(x);
pam.pre_work();

seg.build(1,1,pam.cnt);


for(int i=1;i<=q;i++){
int l,r;cin>>l>>r;
qs[r].push_back({l,i});
}

int n = s.size();
for(int r=1;r<=n;r++){
// cerr << "R = " << r << endl;
for(int u = pam.nodes[r]; u>1 ;u=pam.slink[u]){
int last_pos = pam.last_pos(u);
if(last_pos) last_pos = last_pos- pam.len[u] + 1;
bit.add(last_pos+1, 1);
bit.add(r-pam.len[pam.slink[u]]-pam.diff[u]+2, -1);
}

for(auto [l, id]:qs[r]){
ans[id] = bit.query(l);
// cerr << "ans of " << id << " is " << ans[id] << '\n';
}

pam.update_pos(pam.nodes[r], r);
}

for(int i=1;i<=q;i++)cout<<ans[i]<<'\n';
}

int main(){
ios::sync_with_stdio(0); cin.tie(0);
int T; cin>>T;
while(T--)solve();
return 0;
}

练习

参考资料