数位 DP
数位 DP 是一种快速求解出在 [ 1 , n ] [1,n] [ 1 , n ] 满足条件的数个数的方法。其特点有数据范围一般均为 1 0 18 10^{18} 1 0 18 次方量级(甚至更高)。
对于这种问题,我们是有一套通用的记忆化搜索模板的。
记忆化搜索
数位 DP 有不使用记忆化搜索的递推写法,但是记忆化搜索写法具有好实现的优点,而且一般记忆化搜索是可以做所有数位 DP 的题目的。
我们从一道题目出发来讲解通用的记忆化搜索方法。
示例题目
LOJ#10164. 「一本通 5.3 例 2」数字游戏
给定两个正整数 a a a 和 b b b ,求在 [ a , b ] [a,b] [ a , b ] 中的所有整数中,有多少个数满足从左到右各位数字成小于等于的关系。
1 ≤ a ≤ b ≤ 2 31 − 1 1\le a\le b\le 2^{31}-1 1 ≤ a ≤ b ≤ 2 31 − 1 。
解法
首先我们需要知道,对于区间 [ a , b ] [a,b] [ a , b ] 来说,可以拆分成求 [ 1 , b ] [1,b] [ 1 , b ] 和 [ 1 , a − 1 ] [1,a-1] [ 1 , a − 1 ] 的答案,然后将结果相减,因此后面我们只需考虑求出 [ 1 , n ] [1,n] [ 1 , n ] 的所有整数中各数码出现个数。(这种前缀和的思想在数位 DP 中是很通用的,大多数题目都可以这么转化)
对于记忆化搜索来说,标准的格式是 f ( u , flag , lim , zero ) f(u,\textit{flag},\textit{lim},\textit{zero}) f ( u , flag , lim , zero ) ,其中 u u u 代表的是搜索到哪一位,flag \textit{flag} flag 代表题目所需要的信息,lim \textit{lim} lim 代表是否贴着 n n n 的上限,zero \textit{zero} zero 代表目前是否全是 0 0 0 。
对于本题来说,flag \textit{flag} flag 的信息就是上一次填入的数字,这样就可以处理题目各位数字不下降的要求。
对于记忆化搜索的转移,每次枚举当前填的数码(借助 flag , lim \textit{flag},\textit{lim} flag , lim 的信息确定能填哪些数码),然后将所有 f ( u − 1 , ⋅ , ⋅ , ⋅ ) f(u-1,\cdot ,\cdot,\cdot) f ( u − 1 , ⋅ , ⋅ , ⋅ ) 的结果进行求和返回。
来看具体代码(请阅读注释了解更多信息):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 #include <bits/stdc++.h> using namespace std;const int N = 12 ;int p[N];int dp[N][10 ];int dfs (int u,int flag,bool lim,bool zero) { if (u==0 ){ return !zero; } if (!lim && !zero && ~dp[u][flag])return dp[u][flag]; int L = flag, R = (lim?p[u]:9 ); int ans = 0 ; for (int i=L;i<=R;i++){ ans += dfs (u-1 , i, lim&&(i==p[u]), zero&&!i); } if (!lim && !zero)dp[u][flag]=ans; return ans; }int solve (int n) { int cnt = 0 ; while (n){ p[++cnt]=n%10 , n/=10 ; } memset (dp, -1 , sizeof (dp)); return dfs (cnt,0 ,1 ,1 ); }int main () { ios::sync_with_stdio (0 );cin.tie (0 ); int a,b; while (cin>>a>>b){ cout << solve (b) - solve (a-1 ) << '\n' ; } return 0 ; }
需要注意的是,我们对 lim , zero \textit{lim},\textit{zero} lim , zero 有一个为 true
的情况没有记忆化保存,这是因为这两个参数有任意一个为 true
的时候要么是一直贴着上界,要么是一直填 0 0 0 ,两种情况都不会被重复访问,所以不记忆化对复杂度也没有影响。(当然,记忆化 4 4 4 个所有参数也是正确的)
例题
接下来我们会通过一些例题来讲解记忆化搜索中的一些变式。
例题 1
LOJ#10168. 「一本通 5.3 练习 3」恨 7 不成妻
求 [ L , R ] [L,R] [ L , R ] 中满足一下条件的整数的平方和 :
不含有数字 7 7 7 ;
不是 7 7 7 的倍数;
数位和不是 7 7 7 的倍数。
1 ≤ L ≤ R ≤ 1 0 18 1\le L \le R \le 10^{18} 1 ≤ L ≤ R ≤ 1 0 18 ,对 1 0 9 + 7 10^9 +7 1 0 9 + 7 取模。
先考虑如何设计本题的 flag \textit{flag} flag ,第一个条件很好处理,只需要不枚举 7 7 7 就可以了,第二个条件要求我们储存当前的数对 7 7 7 的余数,第三个条件要求我们储存当前的数位和对 7 7 7 的余数,最后在填完所有数的时候进行判断即可。
因此本体需要两个 flag \textit{flag} flag ,分别表示当前填的这些数位所代表的数对 7 7 7 的余数和当前填的数位的数位和对 7 7 7 的余数。
本题还有一个变化点在于要求的是数的平方和,这个如何处理呢?
考虑我们已经求出了填后面 u − 1 u-1 u − 1 位的结果,现在第 u u u 位上填的是 i i i ,要求的就是 ∑ x ∈ Ans ( u − 1 ) ( i ⋅ 1 0 u − 1 + x ) 2 \displaystyle\sum_{x\in \textit{Ans}(u-1)} (i\cdot10^{u-1} + x)^2 x ∈ Ans ( u − 1 ) ∑ ( i ⋅ 1 0 u − 1 + x ) 2 ,其中 Ans ( u − 1 ) \textit{Ans}(u-1) Ans ( u − 1 ) 表示后面 u − 1 u-1 u − 1 位上合法每个数的集合。
因为有二项式定理,展开我们发现,只需要知道 Ans ( u − 1 ) \textit{Ans}(u-1) Ans ( u − 1 ) 中元素的 0 , 1 , 2 0,1,2 0 , 1 , 2 次方和就能求出这个。(对于求当前的 0 , 1 0,1 0 , 1 次方和同理)
因此我们对于每个记忆化搜索,都保存当前答案的 0 , 1 , 2 0,1,2 0 , 1 , 2 次方,在更新时用二项式定理维护即可解决求平方和的问题。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 #include <bits/stdc++.h> using namespace std;using ll=long long ; const ll MOD = 1e9 +7 ;ll fastpow (ll a,ll b) { ll res = 1 ; while (b){ if (b&1 )res=(res*a)%MOD; a=(a*a)%MOD; b>>=1 ; } return res; }struct Node { ll cnt,sum,sqsum; };int p[25 ]; Node dp[25 ][7 ][7 ]; ll pow10_mod7[25 ]; Node dfs (ll u, bool lim, bool zero,ll r1,ll r2) { if (u==0 ){ if (r1==0 || r2==0 )return Node{0 ,0 ,0 }; else return Node{1 ,0 ,0 }; } if (!lim && !zero && ~dp[u][r1][r2].cnt) return dp[u][r1][r2]; Node ans{0 ,0 ,0 }; int up=lim?p[u]:9 ; for (int i=0 ;i<=up;i++){ if (i==7 )continue ; ll k = fastpow (10 ,u-1 )*i%MOD; auto res = dfs (u-1 ,lim&&i==p[u],zero&&i==0 , (r1+i)%7 ,(r2+pow10_mod7[u-1 ]*i)%7 ); ans.cnt = (ans.cnt+res.cnt)%MOD; ans.sum = (ans.sum+res.sum)%MOD; ans.sum = (ans.sum+res.cnt*k%MOD)%MOD; ans.sqsum = (ans.sqsum+res.sqsum)%MOD; ans.sqsum = (ans.sqsum+k*2 *res.sum)%MOD; ans.sqsum = (ans.sqsum+k*k%MOD*res.cnt)%MOD; } if (!zero && !lim)dp[u][r1][r2]=ans; return ans; }ll solve (ll n) { memset (dp,-1 ,sizeof (dp)); int cnt = 0 ; while (n){ p[++cnt]=n%10 ; n/=10 ; } return dfs (cnt,1 ,1 ,0 ,0 ).sqsum; }int main () { ios::sync_with_stdio (0 );cin.tie (0 ); pow10_mod7[0 ]=1 ; for (int i=1 ;i<25 ;i++){ pow10_mod7[i]=pow10_mod7[i-1 ]*10 %7 ; } int T; cin>>T; while (T--){ ll a,b; cin>>a>>b; cout << (solve (b)-solve (a-1 )+MOD)%MOD << '\n' ; } return 0 ; }
例题 2
ABC317F Nim
给定 N , A 1 , A 2 , A 3 N,A_1,A_2,A_3 N , A 1 , A 2 , A 3 ,求满足以下条件的三元组 ( X 1 , X 2 , X 3 ) (X_1,X_2,X_3) ( X 1 , X 2 , X 3 ) 个数,对 998244353 998244353 998244353 取模:
∀ i , 1 ≤ X i ≤ N \forall i, 1\le X_i \le N ∀ i , 1 ≤ X i ≤ N ;
∀ i \forall i ∀ i ,X i X_i X i 是 A i A_i A i 的倍数;
X 1 ⊕ X 2 ⊕ X 3 = 0 X_1\oplus X_2 \oplus X_3 = 0 X 1 ⊕ X 2 ⊕ X 3 = 0 。
N ≤ 1 0 18 , 1 ≤ A i ≤ 10 N\le 10^{18},1\le A_i\le 10 N ≤ 1 0 18 , 1 ≤ A i ≤ 10 。
本题的关键在于求出三元组的个数。之前的题目都是求出数的个数,事实上,之前的方法完全可以用来做求 n n n 元组个数。
我们只需要将 f ( u , flag , lim , zero ) f(u,\textit{flag},\textit{lim},\textit{zero}) f ( u , flag , lim , zero ) 修改成 f ( u , flag 1 , flag 2 , flag 3 , lim 1 , lim 2 , lim 3 , zero 1 , zero 2 , zero 3 ) f(u,\textit{flag}_1,\textit{flag}_2,\textit{flag}_3,\textit{lim}_1,\textit{lim}_2,\textit{lim}_3,\textit{zero}_1,\textit{zero}_2,\textit{zero}_3) f ( u , flag 1 , flag 2 , flag 3 , lim 1 , lim 2 , lim 3 , zero 1 , zero 2 , zero 3 ) 即可。
同时在枚举第 u u u 位的时候,要枚举的是 3 3 3 个数位共 b 3 b^3 b 3 种情况(其中 b b b 是进制,本题因为有异或的条件,应该采用 2 2 2 进制)
特别需要注意,此时所有参数都要记忆化,因为一个数贴着上限其他数还有可能有很多种可能,不再是只会访问一次的了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 #include <bits/stdc++.h> using namespace std;using ll=long long ;const int MOD = 998244353 ;int dp[80 ][12 ][12 ][12 ][2 ][2 ][2 ][2 ][2 ][2 ]; int p[80 ]; int a1,a2,a3;int dfs (int u,int r1,int r2,int r3,bool l1,bool l2,bool l3, bool z1,bool z2,bool z3) { if (u==0 ){ return !z1 &&!z2&&!z3 && !r1&&!r2&&!r3; } if (dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]!=-1 ) return dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]; int up1=l1?p[u]:1 ,up2=l2?p[u]:1 ,up3=l3?p[u]:1 ; int ans = 0 ; for (ll i=0 ;i<=up1;i++){ for (ll j=0 ;j<=up2;j++){ for (ll k=0 ;k<=up3;k++){ if ((i^j^k)!=0 )continue ; int newr1 = ((ll)r1+(i<<(u-1 )))%a1; int newr2 = ((ll)r2+(j<<(u-1 )))%a2; int newr3 = ((ll)r3+(k<<(u-1 )))%a3; ans = (ans +dfs (u-1 ,newr1,newr2,newr3, l1&&i==p[u],l2&&j==p[u],l3&&k==p[u], z1&&i==0 ,z2&&j==0 ,z3&&k==0 ))%MOD; } } } return dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]=ans; }int main () { ios::sync_with_stdio (0 );cin.tie (0 ); memset (dp,-1 ,sizeof (dp)); ll n; cin>>n>>a1>>a2>>a3; int cnt = 0 ; ll x=n; while (x) p[++cnt]=x%2 , x>>=1 ; ll ans = dfs (cnt,0 ,0 ,0 ,1 ,1 ,1 ,1 ,1 ,1 ); cout << ans << '\n' ; return 0 ; }
练习
参考资料