数位 DP

数位 DP 是一种快速求解出在 [1,n][1,n] 满足条件的数个数的方法。其特点有数据范围一般均为 101810^{18} 次方量级(甚至更高)。

对于这种问题,我们是有一套通用的记忆化搜索模板的。

记忆化搜索

数位 DP 有不使用记忆化搜索的递推写法,但是记忆化搜索写法具有好实现的优点,而且一般记忆化搜索是可以做所有数位 DP 的题目的。

我们从一道题目出发来讲解通用的记忆化搜索方法。

示例题目

LOJ#10164. 「一本通 5.3 例 2」数字游戏

给定两个正整数 aabb,求在 [a,b][a,b] 中的所有整数中,有多少个数满足从左到右各位数字成小于等于的关系。

1ab23111\le a\le b\le 2^{31}-1

解法

首先我们需要知道,对于区间 [a,b][a,b] 来说,可以拆分成求 [1,b][1,b][1,a1][1,a-1] 的答案,然后将结果相减,因此后面我们只需考虑求出 [1,n][1,n] 的所有整数中各数码出现个数。(这种前缀和的思想在数位 DP 中是很通用的,大多数题目都可以这么转化)

对于记忆化搜索来说,标准的格式是 f(u,flag,lim,zero)f(u,\textit{flag},\textit{lim},\textit{zero}),其中 uu 代表的是搜索到哪一位,flag\textit{flag} 代表题目所需要的信息,lim\textit{lim} 代表是否贴着 nn 的上限,zero\textit{zero} 代表目前是否全是 00

对于本题来说,flag\textit{flag} 的信息就是上一次填入的数字,这样就可以处理题目各位数字不下降的要求。

对于记忆化搜索的转移,每次枚举当前填的数码(借助 flag,lim\textit{flag},\textit{lim} 的信息确定能填哪些数码),然后将所有 f(u1,,,)f(u-1,\cdot ,\cdot,\cdot) 的结果进行求和返回。

来看具体代码(请阅读注释了解更多信息):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;

const int N = 12;
int p[N];
int dp[N][10];

// dfs 函数负责记忆化搜索
// 返回填 u 以下的位,最后一次填的是flag,贴上限为lim,是否填数为zero的结果
int dfs(int u,int flag,bool lim,bool zero){
if(u==0){ // 填完了
return !zero; // 只要不是 0 都满足条件
}

// 只有lim和zero都为false的时候才进行记忆化
if(!lim && !zero && ~dp[u][flag])return dp[u][flag];

// 计算这一位能填的上下界
int L = flag, // 最小能填的是上一次填的数
R = (lim?p[u]:9); // 如果贴着上界,最多只能填到 p[u]

int ans = 0;

for(int i=L;i<=R;i++){
// 这一位填入 i,并计算新的 lim 和 zero
ans += dfs(u-1, i, lim&&(i==p[u]), zero&&!i);
}
if(!lim && !zero)dp[u][flag]=ans;
return ans;
}

// solve 函数负责将 n 拆解,初始化 dp 数组
int solve(int n){
int cnt = 0; // n 的位数
while(n){ // 将数拆入p数组
p[++cnt]=n%10, n/=10;
}
memset(dp, -1, sizeof(dp)); // 注意这里不能初始化成 0,否则复杂度会错
return dfs(cnt,0,1,1); // 一开始贴着上限,填的都是0
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);
int a,b;
while(cin>>a>>b){
cout << solve(b) - solve(a-1) << '\n';
}
return 0;
}

需要注意的是,我们对 lim,zero\textit{lim},\textit{zero} 有一个为 true 的情况没有记忆化保存,这是因为这两个参数有任意一个为 true 的时候要么是一直贴着上界,要么是一直填 00,两种情况都不会被重复访问,所以不记忆化对复杂度也没有影响。(当然,记忆化 44 个所有参数也是正确的)

例题

接下来我们会通过一些例题来讲解记忆化搜索中的一些变式。

例题 1

LOJ#10168. 「一本通 5.3 练习 3」恨 7 不成妻

[L,R][L,R] 中满足一下条件的整数的平方和

  • 不含有数字 77
  • 不是 77 的倍数;
  • 数位和不是 77 的倍数。

1LR10181\le L \le R \le 10^{18},对 109+710^9 +7 取模。


先考虑如何设计本题的 flag\textit{flag},第一个条件很好处理,只需要不枚举 77 就可以了,第二个条件要求我们储存当前的数对 77 的余数,第三个条件要求我们储存当前的数位和对 77 的余数,最后在填完所有数的时候进行判断即可。

因此本体需要两个 flag\textit{flag},分别表示当前填的这些数位所代表的数对 77 的余数和当前填的数位的数位和对 77 的余数。

本题还有一个变化点在于要求的是数的平方和,这个如何处理呢?

考虑我们已经求出了填后面 u1u-1 位的结果,现在第 uu 位上填的是 ii,要求的就是 xAns(u1)(i10u1+x)2\displaystyle\sum_{x\in \textit{Ans}(u-1)} (i\cdot10^{u-1} + x)^2,其中 Ans(u1)\textit{Ans}(u-1) 表示后面 u1u-1 位上合法每个数的集合。

因为有二项式定理,展开我们发现,只需要知道 Ans(u1)\textit{Ans}(u-1) 中元素的 0,1,20,1,2 次方和就能求出这个。(对于求当前的 0,10,1 次方和同理)

因此我们对于每个记忆化搜索,都保存当前答案的 0,1,20,1,2 次方,在更新时用二项式定理维护即可解决求平方和的问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
using namespace std;

using ll=long long;
const ll MOD = 1e9+7;


ll fastpow(ll a,ll b){ // 快速幂
ll res = 1;
while(b){
if(b&1)res=(res*a)%MOD;
a=(a*a)%MOD;
b>>=1;
}
return res;
}

struct Node{ // 0,1,2次方和
ll cnt,sum,sqsum;
};

int p[25];
Node dp[25][7][7];
ll pow10_mod7[25]; // 10^x mod 7 的结果

Node dfs(ll u, bool lim, bool zero,ll r1,ll r2){
if(u==0){
if(r1==0 || r2==0)return Node{0,0,0};
else return Node{1,0,0};
}
if(!lim && !zero && ~dp[u][r1][r2].cnt)
return dp[u][r1][r2];

Node ans{0,0,0};
int up=lim?p[u]:9;
for(int i=0;i<=up;i++){
if(i==7)continue;

ll k = fastpow(10,u-1)*i%MOD;
auto res = dfs(u-1,lim&&i==p[u],zero&&i==0,
(r1+i)%7,(r2+pow10_mod7[u-1]*i)%7);
ans.cnt = (ans.cnt+res.cnt)%MOD;

ans.sum = (ans.sum+res.sum)%MOD;
ans.sum = (ans.sum+res.cnt*k%MOD)%MOD;

ans.sqsum = (ans.sqsum+res.sqsum)%MOD;
ans.sqsum = (ans.sqsum+k*2*res.sum)%MOD;
ans.sqsum = (ans.sqsum+k*k%MOD*res.cnt)%MOD;
}

if(!zero && !lim)dp[u][r1][r2]=ans;
return ans;
}

ll solve(ll n){
memset(dp,-1,sizeof(dp));
int cnt = 0;
while(n){
p[++cnt]=n%10;
n/=10;
}
return dfs(cnt,1,1,0,0).sqsum;
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);
pow10_mod7[0]=1;
for(int i=1;i<25;i++){
pow10_mod7[i]=pow10_mod7[i-1]*10%7;
}

int T; cin>>T;
while(T--){
ll a,b; cin>>a>>b;
cout << (solve(b)-solve(a-1)+MOD)%MOD << '\n';
}
return 0;
}

例题 2

ABC317F Nim

给定 N,A1,A2,A3N,A_1,A_2,A_3,求满足以下条件的三元组 (X1,X2,X3)(X_1,X_2,X_3) 个数,对 998244353998244353 取模:

  • i,1XiN\forall i, 1\le X_i \le N
  • i\forall iXiX_iAiA_i 的倍数;
  • X1X2X3=0X_1\oplus X_2 \oplus X_3 = 0

N1018,1Ai10N\le 10^{18},1\le A_i\le 10


本题的关键在于求出三元组的个数。之前的题目都是求出数的个数,事实上,之前的方法完全可以用来做求 nn 元组个数。

我们只需要将 f(u,flag,lim,zero)f(u,\textit{flag},\textit{lim},\textit{zero}) 修改成 f(u,flag1,flag2,flag3,lim1,lim2,lim3,zero1,zero2,zero3)f(u,\textit{flag}_1,\textit{flag}_2,\textit{flag}_3,\textit{lim}_1,\textit{lim}_2,\textit{lim}_3,\textit{zero}_1,\textit{zero}_2,\textit{zero}_3) 即可。

同时在枚举第 uu 位的时候,要枚举的是 33 个数位共 b3b^3 种情况(其中 bb 是进制,本题因为有异或的条件,应该采用 22 进制)

特别需要注意,此时所有参数都要记忆化,因为一个数贴着上限其他数还有可能有很多种可能,不再是只会访问一次的了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<bits/stdc++.h>
using namespace std;

using ll=long long;
const int MOD = 998244353;
int dp[80][12][12][12][2][2][2][2][2][2]; // 记忆化数组
int p[80]; // p 用来存上限的每一位,从下标1开始储存,下标越大,越高位
int a1,a2,a3;

// u 当前搜索的位数
// r1,2,3 目前 x1,2,3 对 a1,2,3的余数
// l1,2,3 x1,2,3是否贴着上界
int dfs(int u,int r1,int r2,int r3,bool l1,bool l2,bool l3,
bool z1,bool z2,bool z3){
if(u==0){
// 并且都填过了数
// 搜索结束了,那么只要三个余数都为0,这就是一个合法的答案
return !z1 &&!z2&&!z3 && !r1&&!r2&&!r3;
}
// 我们记忆化所有参数
if(dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]!=-1)
return dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3];

// 三个数的上界
int up1=l1?p[u]:1,up2=l2?p[u]:1,up3=l3?p[u]:1;
int ans = 0;
for(ll i=0;i<=up1;i++){
for(ll j=0;j<=up2;j++){
for(ll k=0;k<=up3;k++){
if((i^j^k)!=0)continue; // 保证异或和为0

// 新的余数是原来的余数加上这一位的贡献模a
int newr1 = ((ll)r1+(i<<(u-1)))%a1;
int newr2 = ((ll)r2+(j<<(u-1)))%a2;
int newr3 = ((ll)r3+(k<<(u-1)))%a3;

// 递归新的
ans = (ans
+dfs(u-1,newr1,newr2,newr3,
l1&&i==p[u],l2&&j==p[u],l3&&k==p[u],
z1&&i==0,z2&&j==0,z3&&k==0))%MOD;
}
}
}
return dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]=ans; // 保存记忆化
}

int main(){
ios::sync_with_stdio(0);cin.tie(0);
memset(dp,-1,sizeof(dp));

ll n;
cin>>n>>a1>>a2>>a3;
int cnt = 0;

ll x=n;
while(x) p[++cnt]=x%2, x>>=1;
ll ans = dfs(cnt,0,0,0,1,1,1,1,1,1);
cout << ans << '\n';
return 0;
}

练习

参考资料