kuangbin带你飞 专题十五 数位DP

目录

CodeForces 55D Beautiful numbers

大意 :

给出n和m,求出n和m之间的美丽数的数量。

美丽数的定义是每个数位上的数都能整除这个数

思路:

首先需要知道的是,一个数能被每个数位整除,那么它能被所有数位的lcm整除

其次,因为1到9的lcm为2520,那么判断一个数能否被某个数位整除,可以先将这个数对2520取模

然后就可以解决这个问题了,\(dp[pos][stat][lcm]\)代表当前为pos位,已经取的数位的lcm为lcm,已经取得数对2520取模为stat

但是lcm没必要开2520,因为现在取到的数,肯定是2520的因数,所以可以将lcm离散化,2520的因数只有不到50个,这样就把一个维度从2520压缩到了50

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[100], cnt, prime[2550], pc;
const int mod = 2520;
LL dp[32][2550][50];

LL dfs(int pos, int limit, int lcm, int stat) {
    if (pos == -1) {
        return stat % lcm == 0;
    }
    if ((!limit)&&dp[pos][stat][prime[lcm]] != -1)
        return dp[pos][stat][prime[lcm]];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;

    LL res = 0;
    for (int i = 0; i <= up; i++) {
        int ne = lcm;
        if (i) ne = (lcm * i) / __gcd(lcm, i);
        res += dfs(pos - 1, limit && (i == num[pos]), ne, (stat * 10 + i) % 2520);
    }
    if(!limit)
    dp[pos][stat][prime[lcm]] = res;
    return res;
}

LL solve(LL n) {
    //memset(dp, -1, sizeof dp);
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 1, 1, 0);
}

int main() {
    LL n, m;
    int t;
    memset(dp, -1, sizeof dp);
    cin >> t;
    for (int i = 1; i <= mod; i++) {
        if (mod % i == 0) prime[i] = pc++;
    }
    while (t--) {
        cin >> n >> m;
        cout << solve(m) - solve(n - 1) << endl;
    }
    return 0;
}

HDU 4352 XHXJ's LIS

大意 :

要求找出n和m之间满足:数位中最长上升子序列的长度为k的数 的个数

k<=10

思路:

此题需要记录状态,那么考虑状态压缩,即维护当前lis的状态,由于数位只可能是0到9,所以我们可以利用一个十位的二进制数,表示当前lis中出现了哪些数,然后利用类似\(O(nlog(n))\)求lis的方法去更新这个状态即可,即每次在lis里找比x大的数,然后删掉这个数,用x去替换

注意需要记录前导零,因为如果当前有前导零,且当前的数也是0,那么0就不能加入到lis里面,否则0就可以加入到lis里面

另外由于测试数据很多,但是k只有10中可能性,所以可以在记忆化的时候记录k,否则超时

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[32], cnt;

LL dp[32][3000][15];
int k;
int getnum(int n) {
    int res = 0;
    for (int i = 0; i <= 9; i++) {
        if (n & (1 << i)) res++;
    }
    return res;
}

int update(int stat, int x) {
    for (int i = x; i <= 9; i++) {
        if (stat & (1 << i)) {
            stat -= (1 << i);
            stat += (1 << x);
            return stat;
        }
    }
    return stat + (1 << x);
}

LL dfs(int pos, int stat, int limit, int pre) {//pre记录前导0
    if (pos == -1) {
        return getnum(stat)==k;
    }
    if (!limit && dp[pos][stat][k] != -1) return dp[pos][stat][k];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        if (pre == 1 && i == 0)
            res += dfs(pos - 1, 0, limit && i == num[pos], 1);
        else
            res += dfs(pos - 1, update(stat, i), limit && i == num[pos], 0);
    }
    if (!limit) dp[pos][stat][k] = res;
    return res;
}
 
LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 0, 1, 1);
}

int main() {
    LL n, m;
    int t;
    cin >> t;
    int cases = 0;
    memset(dp, -1, sizeof dp);
    while (t--) {
        cases++;
        cin >> n >> m >> k;
        
        cout << "Case #" << cases << ": ";
        cout << solve(m) - solve(n - 1) << endl;
    }
    return 0;
}

HDU 2089 不要62

大意 :

找出n到m中的吉利数,不吉利的数字为所有含有4或62的号码。

思路:

模板数位DP

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[20], cnt, dp[20][2];

int dfs(int pos, int stat, int limit) {
    if (pos == -1) return 1;
    if (!limit && dp[pos][stat] != -1) return dp[pos][stat];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    int res = 0;
    for (int i = 0; i <= up; i++) {
        if (i == 4) continue;
        if (stat && i == 2) continue;
        res += dfs(pos - 1, i == 6, limit && i == num[pos]);
    }
    if (!limit) dp[pos][stat] = res;
    return res;
}

int solve(int n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 0, 1);
}

int main() {
    int n, m;
    memset(dp, -1, sizeof dp);
    while (scanf("%d%d", &n, &m) && (n + m != 0)) {
        printf("%d\n", solve(m) - solve(n - 1));
    }
    return 0;
}

HDU 3555 Bomb

大意 :

要求求出1到N中 数位里出现过49的数 的数量

思路:

dfs(pos, stat,limit,pre)代表第pos位,是否出现过49(stat),是否达到上界,前一个数是多少

注意dp数组也要相应开3维,分别记录pos,stat,pre

一开始没有开pre这一维,但是情况会被覆盖,导致全0

总结一下就是dfs除了limit以外有多少参数dp就要开多少维来记录状态,否则状态会重叠

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[32], cnt;

LL dp[32][2][20];

LL dfs(int pos, int stat, int limit,int pre) {
    //cout << stat << endl;
    if (pos == -1) return stat;
    if (!limit && dp[pos][stat][pre] != -1) return dp[pos][stat][pre];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        //int ne = stat | (pre == 4 && i == 9);
        if(pre == 4 && i == 9)
        res += dfs(pos - 1, 1, limit && i == num[pos],i);
        else
        res += dfs(pos - 1, stat, limit && i == num[pos],i);
    }
    if (!limit) dp[pos][stat][pre] = res;
    return res;
}

LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 0, 1,0);
}

int main() {
    int t;
    LL n;
    memset(dp, -1, sizeof dp);
    cin >> t;
    while (t--) {
        cin >> n;
        cout << solve(n) << endl;
    }
    return 0;
}

POJ 3252 Round Numbers

大意 :

求出n到m之间的Round数,Round数满足:二进制表示中0的个数大于等于1的个数

思路:

\(dfs(pos, num1, num0, limit,pre)\),

num1和num0来记录当前二进制中0和1的个数,pre来记录是否为前导0,只有不是前导0的情况才能把0加入num0中

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <queue>
#include <string>


using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[70], cnt;

LL dp[70][70][70][2];

LL dfs(int pos, int num1, int num0, int limit,int pre) {
    if (pos == -1) return num0 >= num1;
    if (!limit && dp[pos][num1][num0][pre] != -1) return dp[pos][num1][num0][pre];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 1;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        res += dfs(pos - 1, num1 + (i == 1), num0 + (pre==0&&i == 0),limit && i == num[pos],pre==1&&i==0);
    }
    if (!limit) dp[pos][num1][num0][pre] = res;
    return res;
}

LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 2;
        n /= 2;
    }
    return dfs(cnt - 1, 0, 0, 1,1);
}

int main() {
    LL n, m;
    memset(dp, -1, sizeof dp);
    cin >> n >> m;
    cout << solve(m) - solve(n - 1) << endl;
    return 0;
}

HDU 3709 Balanced Number

大意 :

求出n和m之间的平衡数的数量,平衡数的定义是能找到一个数位,使得左边数位的值 乘 位置差=右边数位的值 乘 位置差

思路:

对于每平衡数,有且仅有一个平衡点,那么可以枚举平衡的位置,进行dfs,dfs时还是从左到右试填,但是需要记录数位的值乘上位置差的和,最后看这个和是不是0即可,又可以考虑到,当这个和小于0时必然无解,因为越过平衡点后,越靠右这个和肯定是越小的

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[20], cnt;

LL dp[20][20][5000];

LL dfs(int pos, int stat, int mid, int limit) {
    if (pos == -1) return stat == 0;
    if (stat < 0) return 0;
    if (!limit && dp[pos][mid][stat] != -1) return dp[pos][mid][stat];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        res += dfs(pos - 1, stat + i * (pos - mid),mid, limit && (i == num[pos]));
    }
    if (!limit) dp[pos][mid][stat] = res;
    return res;
}

LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    LL res = 0;
    for (int i = 0; i < cnt; i++) {
        res += dfs(cnt - 1, 0, i, 1);
    }
    return res - (cnt - 1);
}

int main() {
    int t;
    LL n, m;
    memset(dp, -1, sizeof dp);
    cin >> t;
    while (t--) {
        cin >> n >> m;
        cout << solve(m) - solve(n - 1) << endl;
    }
    return 0;
}

HDU 3652 B-number

大意 :

求出0到n之间的B-数,B-数的定义是含有13且能被13整除的数

思路:

和Beautiful numbers那道题有点像,记录当前对13取余的值,以及是否出现过13,以及前一个数是不是1即可

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[20], cnt;

LL dp[20][2][2][20];

LL dfs(int pos, int stat, int have, int pre,int limit) {
    if (pos == -1) return have && stat %13==0;
    if (!limit && dp[pos][have][pre][stat] != -1) return dp[pos][have][pre][stat];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        res += dfs(pos - 1, (stat*10+i)%13,have|(pre==1&&i==3), (i==1),limit && (i == num[pos]));
    }
    if (!limit) dp[pos][have][pre][stat] = res;
    return res;
}

LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt-1,0,0,0,1);
}

int main() {
    int t;
    LL n, m;
    memset(dp, -1, sizeof dp);
    //cin >> t;
    while (scanf("%lld",&n)!=EOF) {
        cout << solve(n) << endl;
    }
    return 0;
}

HDU 4734 F(x)

大意 :

给出两个数n和m,求0到m之间f(x)小于f(n)的数量

f函数的定义;

\(F(x) = A_n * 2^{n-1} + A_{n-1} * 2^{n-2} + ... + A_2 * 2 + A_1 * 1\)

思路:

常规的想法是\(dp[pos][stat]\)来记录前pos位,前缀状态为stat下的数量,状态只有不到5000个,可以保存,但是由于n不同,每次都需要清空dp数组,这样会超时,那么又可以想到加一维记录n,但是这样数组大小开到了1e8,内存会爆。

可以利用做差来优化,\(dp[pos][fa-stat]\)记录还差fa-stat达到fa的状态,这样不管n怎么变都不影响dp数组,但是需要注意的是,这样stat就不能写成\((stat<<1)+i\)来转移,必须要写成\(stat+(i<<pos)\)

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[32], cnt;
int numa;
map<int, int> mp;
LL dp[10][5000];

int getnum(LL x) {
    int res = 0;
    int base = 1;
    while(x){
        int tmp = x % 10;
        res += tmp * base;
        base *= 2;
        x /= 10;
    }
    return res;
}

//int update(int stat, int x) { return stat << 1 + x; }

LL dfs(int pos, int stat, int limit,int fa) {
    if (pos == -1) return stat<=fa;
    if (stat > fa) return 0;
    if (!limit && dp[pos][fa-stat]!= -1)
        return dp[pos][fa-stat];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        res += dfs(pos - 1, stat + (i<<pos), limit && (i == num[pos]), fa);
    }
    if (!limit) dp[pos][fa-stat] = res;
    return res;
}

LL solve(LL n, LL m) {
    //memset(dp, -1, sizeof dp);
    cnt = 0;
    int tmp = getnum(n);
    
    while (m) {
        num[cnt++] = m % 10;
        m /= 10;
    }
    return dfs(cnt - 1, 0, 1,tmp);
}

int main() {
    int t;
    LL n, m;
    memset(dp, -1, sizeof dp);
    cin >> t;
    int cases = 0;
    while (t--) {
        cases++;
        cin >> n >> m;
        cout << "Case #" << cases << ": ";
        cout << solve(n, m) << endl;
    }
    return 0;
}

ZOJ 3494 BCD Code

大意 :

求出n到m之间的数以bcd码表示中不含被禁的字符串的数量

思路:

AC自动机+数位dp即可

HDU 4507 吉哥系列故事――恨7不成妻

大意 :

求出n到m内和7无关的数字的平方和。

如果一个整数符合下面3个条件之一,那么我们就说这个整数和7有关——
1、整数中某一位是7;
2、整数的每一位加起来的和是7的整数倍;
3、这个整数是7的整数倍;

思路:

很容易想到\(dp[pos][stat][sum]\)代表前缀为stat%7,数位的和为sum%7

但是要求平方和,上面的dp只能求个数,求不了平方和,所以可以利用类似线段树求平方和的方法,用结构体维护cnt,sum,sqrsum

转移时将\((A+B)^2\)拆开分析即可,具体看代码(注意输出时要+mod再取模):

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[32], cnt;
struct node {
    LL cnt;     //记录个数
    LL sum;     //记录和
    LL sqrsum;  //记录平方和
} dp[25][10][10];
LL pw[25];

const LL mod = 1e9 + 7;

void init() {
    pw[0] = 1;
    for (int i = 1; i <= 20; i++) {
        pw[i] = pw[i - 1] * 10 % mod;
    }
}

node dfs(int pos, int stat, int limit,
         int sum) {  // stat记录整体mod7 sum记录数位和mod7
    if (pos == -1) {
        node tmp;
        tmp.cnt = 1;
        tmp.sum = 0;
        tmp.sqrsum = 0;
        if (sum % 7 == 0) tmp.cnt = 0;
        if (stat % 7 == 0) tmp.cnt = 0;
        return tmp;
    }
    if (!limit && dp[pos][stat][sum].cnt != -1) return dp[pos][stat][sum];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    node res;
    res.cnt = res.sqrsum = res.sum = 0;
    for (int i = 0; i <= up; i++) {
        if (i == 7) continue;
        node temp;
        temp = dfs(pos - 1, (stat * 10 + i) % 7, limit && i == num[pos],
                   (sum + i) % 7);

        res.cnt += temp.cnt;
        res.cnt %= mod;

        res.sum += (temp.sum + temp.cnt * (pw[pos] * i % mod) % mod) % mod;
        res.sum %= mod;

        res.sqrsum +=
            ((temp.sqrsum + 2LL * (pw[pos] * i % mod * temp.sum % mod) % mod) %
                 mod +
             temp.cnt * pw[pos] % mod * pw[pos] % mod * i % mod * i % mod) %
            mod;
        res.sqrsum %= mod;
    }
    if (!limit) dp[pos][stat][sum] = res;
    return res;
}

node solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 0, 1, 0);
}

int main() {
    int t;
    LL n, m;
    init();
    memset(dp, -1, sizeof dp);
    cin >> t;
    while (t--) {
        cin >> n >> m;
        cout << (solve(m).sqrsum - solve(n - 1).sqrsum + mod) % mod << endl;
    }
    return 0;
}

SPOJ BALNUM Balanced Numbers

大意 :

求出n到m之间的平衡数,平衡数的定义是数位上每个奇数都出现偶数次,每个偶数都出现奇数次

思路:

因为要同时记录10个数位的状态,可以利用三进制进行状态压缩,0代表没出现过,1代表出现过奇数次,2代表出现过偶数次,这样总状态只有\(3^{10}\)不到6e4个,完全够用

注意一下三进制状态压缩时的judge和update

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 5;
typedef long long LL;
int num[25], cnt;

LL dp[25][60000];

bool judge(int stat) {
    for (int i = 0; i <= 9; i++) {
        if (i % 2 == 1 && stat % 3 == 1) return 0;
        if (i % 2 == 0 && stat % 3 == 2) return 0;
        stat /= 3;
    }
    return 1;
}

int update(int stat, int x) {
    int temp = stat / (pow(3, x));
    stat -= temp * pow(3, x);
    if (temp % 3 == 0)
        temp++;
    else if (temp % 3 == 1)
        temp++;
    else
        temp--;
    stat += temp * pow(3, x);
    return stat;
}

LL dfs(int pos, int stat, int limit) {
    // cout << stat << endl;
    if (pos == -1) return judge(stat);
    if (!limit && dp[pos][stat] != -1) return dp[pos][stat];
    int up;
    if (limit)
        up = num[pos];
    else
        up = 9;
    LL res = 0;
    for (int i = 0; i <= up; i++) {
        if (stat == 0 && i == 0)//前导0 因为可以直接从stat的值看出来是否为前导0,所以不需要pre参数
            res += dfs(pos - 1, 0, limit&&i==num[pos]);
        else
            res += dfs(pos - 1, update(stat, i), limit&&i==num[pos]);
    }
    if (!limit) dp[pos][stat] = res;
    return res;
}

LL solve(LL n) {
    cnt = 0;
    while (n) {
        num[cnt++] = n % 10;
        n /= 10;
    }
    return dfs(cnt - 1, 0, 1);
}

int main() {
    int t;
    LL n, m;
    memset(dp, -1, sizeof dp);
    cin >> t;
    while (t--) {
        cin >> n >> m;
        cout << solve(m) - solve(n - 1) << endl;
    }
    return 0;
}
上一篇:stat()/lstat()函数使用


下一篇:Linux DIR,dirent,stat结构体