数位 DP 
数位 DP 是一种快速求解出在 [ 1 , n ] [1,n] [ 1 , n ]   满足条件的数个数的方法。其特点有数据范围一般均为 1 0 18 10^{18} 1 0 18   次方量级(甚至更高)。
对于这种问题,我们是有一套通用的记忆化搜索模板的。
记忆化搜索 
数位 DP 有不使用记忆化搜索的递推写法,但是记忆化搜索写法具有好实现的优点,而且一般记忆化搜索是可以做所有数位 DP 的题目的。
我们从一道题目出发来讲解通用的记忆化搜索方法。
示例题目 
LOJ#10164. 「一本通 5.3 例 2」数字游戏 
给定两个正整数 a a a   和 b b b  ,求在 [ a , b ] [a,b] [ a , b ]   中的所有整数中,有多少个数满足从左到右各位数字成小于等于的关系。
1 ≤ a ≤ b ≤ 2 31 − 1 1\le a\le b\le 2^{31}-1 1 ≤ a ≤ b ≤ 2 31 − 1  。
解法 
首先我们需要知道,对于区间 [ a , b ] [a,b] [ a , b ]   来说,可以拆分成求 [ 1 , b ] [1,b] [ 1 , b ]   和 [ 1 , a − 1 ] [1,a-1] [ 1 , a − 1 ]   的答案,然后将结果相减,因此后面我们只需考虑求出 [ 1 , n ] [1,n] [ 1 , n ]   的所有整数中各数码出现个数。(这种前缀和的思想在数位 DP 中是很通用的,大多数题目都可以这么转化)
对于记忆化搜索来说,标准的格式是 f ( u , flag , lim , zero ) f(u,\textit{flag},\textit{lim},\textit{zero}) f ( u , flag , lim , zero )  ,其中 u u u   代表的是搜索到哪一位,flag \textit{flag} flag   代表题目所需要的信息,lim \textit{lim} lim   代表是否贴着 n n n   的上限,zero \textit{zero} zero   代表目前是否全是 0 0 0  。
对于本题来说,flag \textit{flag} flag   的信息就是上一次填入的数字,这样就可以处理题目各位数字不下降的要求。
对于记忆化搜索的转移,每次枚举当前填的数码(借助 flag , lim \textit{flag},\textit{lim} flag , lim   的信息确定能填哪些数码),然后将所有 f ( u − 1 , ⋅ , ⋅ , ⋅ ) f(u-1,\cdot ,\cdot,\cdot) f ( u − 1 , ⋅ , ⋅ , ⋅ )   的结果进行求和返回。
来看具体代码(请阅读注释了解更多信息):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 #include <bits/stdc++.h>  using  namespace  std;const  int  N = 12 ;int  p[N];int  dp[N][10 ];int  dfs (int  u,int  flag,bool  lim,bool  zero)  {     if (u==0 ){          return  !zero;      }          if (!lim && !zero && ~dp[u][flag])return  dp[u][flag];          int  L = flag,            R = (lim?p[u]:9 );           int  ans = 0 ;          for (int  i=L;i<=R;i++){                  ans += dfs (u-1 , i, lim&&(i==p[u]), zero&&!i);     }     if (!lim && !zero)dp[u][flag]=ans;     return  ans; }int  solve (int  n)  {     int  cnt = 0 ;      while (n){          p[++cnt]=n%10 , n/=10 ;     }     memset (dp, -1 , sizeof (dp));      return  dfs (cnt,0 ,1 ,1 );  }int  main ()  {     ios::sync_with_stdio (0 );cin.tie (0 );     int  a,b;     while (cin>>a>>b){         cout << solve (b) - solve (a-1 ) << '\n' ;     }     return  0 ; }
 
需要注意的是,我们对 lim , zero \textit{lim},\textit{zero} lim , zero   有一个为 true 的情况没有记忆化保存,这是因为这两个参数有任意一个为 true 的时候要么是一直贴着上界,要么是一直填 0 0 0  ,两种情况都不会被重复访问,所以不记忆化对复杂度也没有影响。(当然,记忆化 4 4 4   个所有参数也是正确的)
例题 
接下来我们会通过一些例题来讲解记忆化搜索中的一些变式。
例题 1 
LOJ#10168. 「一本通 5.3 练习 3」恨 7 不成妻 
求 [ L , R ] [L,R] [ L , R ]   中满足一下条件的整数的平方和 :
不含有数字 7 7 7  ; 
不是 7 7 7   的倍数; 
数位和不是 7 7 7   的倍数。 
 
1 ≤ L ≤ R ≤ 1 0 18 1\le L \le R \le 10^{18} 1 ≤ L ≤ R ≤ 1 0 18  ,对 1 0 9 + 7 10^9 +7 1 0 9 + 7   取模。
 
先考虑如何设计本题的 flag \textit{flag} flag  ,第一个条件很好处理,只需要不枚举 7 7 7   就可以了,第二个条件要求我们储存当前的数对 7 7 7   的余数,第三个条件要求我们储存当前的数位和对 7 7 7   的余数,最后在填完所有数的时候进行判断即可。
因此本体需要两个 flag \textit{flag} flag  ,分别表示当前填的这些数位所代表的数对 7 7 7   的余数和当前填的数位的数位和对 7 7 7   的余数。
本题还有一个变化点在于要求的是数的平方和,这个如何处理呢?
考虑我们已经求出了填后面 u − 1 u-1 u − 1   位的结果,现在第 u u u   位上填的是 i i i  ,要求的就是 ∑ x ∈ Ans ( u − 1 ) ( i ⋅ 1 0 u − 1 + x ) 2 \displaystyle\sum_{x\in \textit{Ans}(u-1)} (i\cdot10^{u-1} + x)^2 x ∈ Ans ( u − 1 ) ∑  ( i ⋅ 1 0 u − 1 + x ) 2  ,其中 Ans ( u − 1 ) \textit{Ans}(u-1) Ans ( u − 1 )   表示后面 u − 1 u-1 u − 1   位上合法每个数的集合。
因为有二项式定理,展开我们发现,只需要知道 Ans ( u − 1 ) \textit{Ans}(u-1) Ans ( u − 1 )   中元素的 0 , 1 , 2 0,1,2 0 , 1 , 2   次方和就能求出这个。(对于求当前的 0 , 1 0,1 0 , 1   次方和同理)
因此我们对于每个记忆化搜索,都保存当前答案的 0 , 1 , 2 0,1,2 0 , 1 , 2   次方,在更新时用二项式定理维护即可解决求平方和的问题。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 #include <bits/stdc++.h>  using  namespace  std;using  ll=long  long ; const  ll MOD = 1e9 +7 ;ll fastpow (ll a,ll b)  {      ll res = 1 ;     while (b){         if (b&1 )res=(res*a)%MOD;         a=(a*a)%MOD;         b>>=1 ;     }     return  res; }struct  Node {      ll cnt,sum,sqsum; };int  p[25 ]; Node dp[25 ][7 ][7 ]; ll pow10_mod7[25 ]; Node dfs (ll u, bool  lim, bool  zero,ll r1,ll r2)  {     if (u==0 ){         if (r1==0  || r2==0 )return  Node{0 ,0 ,0 };         else  return  Node{1 ,0 ,0 };     }     if (!lim && !zero && ~dp[u][r1][r2].cnt)         return  dp[u][r1][r2];     Node ans{0 ,0 ,0 };     int  up=lim?p[u]:9 ;     for (int  i=0 ;i<=up;i++){         if (i==7 )continue ;         ll k = fastpow (10 ,u-1 )*i%MOD;         auto  res = dfs (u-1 ,lim&&i==p[u],zero&&i==0 ,                     (r1+i)%7 ,(r2+pow10_mod7[u-1 ]*i)%7 );         ans.cnt = (ans.cnt+res.cnt)%MOD;                  ans.sum = (ans.sum+res.sum)%MOD;         ans.sum = (ans.sum+res.cnt*k%MOD)%MOD;         ans.sqsum = (ans.sqsum+res.sqsum)%MOD;         ans.sqsum = (ans.sqsum+k*2 *res.sum)%MOD;         ans.sqsum = (ans.sqsum+k*k%MOD*res.cnt)%MOD;     }          if (!zero && !lim)dp[u][r1][r2]=ans;     return  ans; }ll solve (ll n)  {     memset (dp,-1 ,sizeof (dp));     int  cnt = 0 ;     while (n){         p[++cnt]=n%10 ;         n/=10 ;     }     return  dfs (cnt,1 ,1 ,0 ,0 ).sqsum; }int  main ()  {      ios::sync_with_stdio (0 );cin.tie (0 );     pow10_mod7[0 ]=1 ;     for (int  i=1 ;i<25 ;i++){         pow10_mod7[i]=pow10_mod7[i-1 ]*10 %7 ;     }     int  T; cin>>T;     while (T--){         ll a,b; cin>>a>>b;         cout << (solve (b)-solve (a-1 )+MOD)%MOD << '\n' ;     }     return  0 ; }
 
例题 2 
ABC317F Nim 
给定 N , A 1 , A 2 , A 3 N,A_1,A_2,A_3 N , A 1  , A 2  , A 3   ,求满足以下条件的三元组 ( X 1 , X 2 , X 3 ) (X_1,X_2,X_3) ( X 1  , X 2  , X 3  )   个数,对 998244353 998244353 998244353   取模:
∀ i , 1 ≤ X i ≤ N \forall i, 1\le X_i \le N ∀ i , 1 ≤ X i  ≤ N  ; 
∀ i \forall i ∀ i  ,X i X_i X i    是 A i A_i A i    的倍数; 
X 1 ⊕ X 2 ⊕ X 3 = 0 X_1\oplus X_2 \oplus X_3 = 0 X 1  ⊕ X 2  ⊕ X 3  = 0  。 
 
N ≤ 1 0 18 , 1 ≤ A i ≤ 10 N\le 10^{18},1\le A_i\le 10 N ≤ 1 0 18 , 1 ≤ A i  ≤ 10  。
 
本题的关键在于求出三元组的个数。之前的题目都是求出数的个数,事实上,之前的方法完全可以用来做求 n n n   元组个数。
我们只需要将 f ( u , flag , lim , zero ) f(u,\textit{flag},\textit{lim},\textit{zero}) f ( u , flag , lim , zero )   修改成 f ( u , flag 1 , flag 2 , flag 3 , lim 1 , lim 2 , lim 3 , zero 1 , zero 2 , zero 3 ) f(u,\textit{flag}_1,\textit{flag}_2,\textit{flag}_3,\textit{lim}_1,\textit{lim}_2,\textit{lim}_3,\textit{zero}_1,\textit{zero}_2,\textit{zero}_3) f ( u , flag 1  , flag 2  , flag 3  , lim 1  , lim 2  , lim 3  , zero 1  , zero 2  , zero 3  )   即可。
同时在枚举第 u u u   位的时候,要枚举的是 3 3 3   个数位共 b 3 b^3 b 3   种情况(其中 b b b   是进制,本题因为有异或的条件,应该采用 2 2 2   进制)
特别需要注意,此时所有参数都要记忆化,因为一个数贴着上限其他数还有可能有很多种可能,不再是只会访问一次的了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 #include <bits/stdc++.h>  using  namespace  std;using  ll=long  long ;const  int  MOD = 998244353 ;int  dp[80 ][12 ][12 ][12 ][2 ][2 ][2 ][2 ][2 ][2 ]; int  p[80 ]; int  a1,a2,a3;int  dfs (int  u,int  r1,int  r2,int  r3,bool  l1,bool  l2,bool  l3,         bool  z1,bool  z2,bool  z3)  {     if (u==0 ){                            return  !z1 &&!z2&&!z3 && !r1&&!r2&&!r3;     }          if (dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]!=-1 )         return  dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3];               int  up1=l1?p[u]:1 ,up2=l2?p[u]:1 ,up3=l3?p[u]:1 ;     int  ans = 0 ;     for (ll i=0 ;i<=up1;i++){         for (ll j=0 ;j<=up2;j++){             for (ll k=0 ;k<=up3;k++){                 if ((i^j^k)!=0 )continue ;                                                    int  newr1 = ((ll)r1+(i<<(u-1 )))%a1;                 int  newr2 = ((ll)r2+(j<<(u-1 )))%a2;                 int  newr3 = ((ll)r3+(k<<(u-1 )))%a3;                                                   ans = (ans                 +dfs (u-1 ,newr1,newr2,newr3,                     l1&&i==p[u],l2&&j==p[u],l3&&k==p[u],                     z1&&i==0 ,z2&&j==0 ,z3&&k==0 ))%MOD;             }         }     }     return  dp[u][r1][r2][r3][l1][l2][l3][z1][z2][z3]=ans;  }int  main ()  {     ios::sync_with_stdio (0 );cin.tie (0 );     memset (dp,-1 ,sizeof (dp));          ll n;      cin>>n>>a1>>a2>>a3;     int  cnt = 0 ;     ll x=n;     while (x) p[++cnt]=x%2 , x>>=1 ;     ll ans = dfs (cnt,0 ,0 ,0 ,1 ,1 ,1 ,1 ,1 ,1 );     cout << ans << '\n' ;     return  0 ; }
 
练习 
参考资料