【HDU-4436】str2int(广义后缀自动机)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4436

题目大意

给你若干个串,这些串由数字组成,求所有本质不同的串,转换成int型后,求和,对2012取模。

思路&知识点

多个串+求本质不同的串=广义后缀自动机!

然后就能构建出相应的自动机来了,接下来在图上搜索,对于节点\(1\)来说,其不能走\(0\)的边。(因为没有前导\(0\))

对自动机上的节点进行排序,然后从小到大按拓扑序进行\(dp\)。

对于一个节点和其能到达的节点,记为\(u\)和\(v\),记\(sum[u]\)为以当前点为终点,从节点\(1\)出发,所有转为int型之和;记\(ways[u]\)为,从节点\(1\)出发,能够到达当前点的路径数。

从节点\(1\)出发,其能走到的除了\(0\)之外的边,能够到达的路径数初始为\(1\),权值之和初始为边的值。

则推得转移式:\(sum[v] = \sum_{u->v}^{}sum[u]\ast 10 + j \ast ways[u]\)
其中,\(j\)为边的权值,即\(0,1,2,3...\)

注意此处不能单纯地用\(maxlen[i]-maxlen[link[i]]\)来作为\(ways[u]\),因为会出现前导\(0\)和无前导\(0\)的冲突。
以下样例能够很好地说明这个问题:好心的出题人

2
101
123

AC代码

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define inf_int 0x3f3f3f3f
#define inf_ll 0x3f3f3f3f3f3f3f3f
const int MAXN = 2e5 + 5;
const int MAXC = 10;
const int mod = 2012;

class Suffix_Automaton {
public:
    int rt, link[MAXN], maxlen[MAXN], trans[MAXN][MAXC];
    int sum[MAXN], ways[MAXN];

    void init() {
        rt = 1;
        link[1] = maxlen[1] = 0;
        memset(trans[1], 0, sizeof(trans[1]));
        memset(sum, 0, sizeof(sum));
        memset(ways, 0, sizeof(ways));
    }

    Suffix_Automaton() { init(); }

    inline int insert(int ch, int last) {   // main: last = 1
        if (trans[last][ch]) {
            int p = last, x = trans[p][ch];
            if (maxlen[p] + 1 == maxlen[x]) return x;
            else {
                int y = ++rt;
                maxlen[y] = maxlen[p] + 1;
                for (int i = 0; i < MAXC; i++) trans[y][i] = trans[x][i];
                while (p && trans[p][ch] == x) trans[p][ch] = y, p = link[p];
                link[y] = link[x], link[x] = y;
                return y;
            }
        }
        int z = ++rt, p = last;
        memset(trans[z], 0, sizeof(trans[z]));
        maxlen[z] = maxlen[last] + 1;
        while (p && !trans[p][ch]) trans[p][ch] = z, p = link[p];
        if (!p) link[z] = 1;
        else {
            int x = trans[p][ch];
            if (maxlen[p] + 1 == maxlen[x]) link[z] = x;
            else {
                int y = ++rt;
                maxlen[y] = maxlen[p] + 1;
                for (int i = 0; i < MAXC; i++) trans[y][i] = trans[x][i];
                while (p && trans[p][ch] == x) trans[p][ch] = y, p = link[p];
                link[y] = link[x], link[z] = link[x] = y;
            }
        }
        return z;
    }


    int topo[MAXN], topo_id[MAXN];

    int solve() {
        memset(topo, 0, sizeof(topo));
        // get topo index
        int ans = 0;
        for (int i = 1; i <= rt; i++) topo[maxlen[i]]++;
        for (int i = 1; i <= rt; i++) topo[i] += topo[i - 1];
        for (int i = 1; i <= rt; i++) topo_id[topo[maxlen[i]]--] = i;
        sum[1] = 1;
        for (int i = 1; i <= rt; i++) {
            int u = topo_id[i];
          //  printf("u = %d\n", u);
            for (int j = 0; j < MAXC; j++) {
                if (u == 1 && j == 0) continue;
                int v = trans[u][j];
                if (v) {
                    if (u == 1) sum[v] = j, ways[v] = 1;
                    // printf("(maxlen[%d] - maxlen[link[%d]]) = %d\n",u, u, (maxlen[u] - maxlen[link[u]]));
                    else {
                        ways[v] += ways[u];
                        sum[v] = (sum[v] + sum[u] * 10 % mod + j * ways[u] % mod) % mod;
                    }
                }
            }
            //printf("sum[%d] = %d\n", u, sum[u]);
            ans = (ans + sum[u]) % mod;
        }
        return (ans-1+mod)%mod;
    }

    void debug() {
        for (int i = 1; i <= rt; i++) {
            for (int j = 0; j < MAXC; j++) {
                if (trans[i][j]) {
                    printf("trans[%d][%d] = %d\n",i, j, trans[i][j]);
                }
            }
        }
    }

} sa;

char str[MAXN];

int main() {
    int n;
    while (~scanf("%d", &n)) {
        sa.init();
        for (int i = 1; i <= n; i++) {
            scanf("%s", str + 1);
            int last = 1;
            int len = strlen(str + 1);
            for (int j = 1; j <= len; j++) {
                last = sa.insert(str[j] - '0', last);
            }
        }
        printf("%d\n", sa.solve());
    }
}
上一篇:[LOJ171] 最长公共子串 - 后缀自动机


下一篇:shape_trans算子