题意:
定义DNA序列为仅有ATCG
四个字母构成的字符串。现在需要构造一个长度为\(n\)的DNA序列\(w\),并给你\(m\)个DNA序列。对于\(w\)来说,每个字符\(w_i\)都能找到至少一对\(l\), \(r\),\(l \leq i \leq r\),使得[\(w_l\),\(\dots\),\(w_r\)]是给定的\(m\)个DNA序列中的一个DNA序列,输出能构造出来的\(w\)的数量,取模 \(10^9 + 9\)
思路:
转换下题意。构造一个文本串,需要使用若干个模式串(可使用若干次),重复字段可以直接放在一起,需要完全匹配,问你方案数。涉及文本串完全匹配多个模式串,往\(AC\)自动机上靠。按照套路,设\(dp[i][j]\)为在\(AC\)自动机上匹配到第\(j\)个节点,已经构造了\(w\)的前\(i\)位的方案数。但是要求完全覆盖,这样子并不能体现。我们考虑匹配到一个模式串的尾节点的时候,设模式串的长度为\(len\),那么我前面构造出来的\(len\)位都能被完全覆盖。所以我们字典树上在尾节点记录模式串的长度,\(dp\)数组多开一维k,记录之前构造的文本串的结尾还有\(k\)位没有匹配。转移时采用刷表法,用当前状态推到所有与当前状态相关联的状态。由于在自动机上转移,只要下一状态的长度大于\(k\),就是能将\(k\)变成\(0\)的合法转移,其余的情况则会让没有匹配的位数多一位。
注意在跑\(Fail\)指针的时候,要用\(Fail\)指针节点去更新当前节点\(len\)值的最大值。例如TTA
和T
这两个模式串,在树上是会重起来的,T
在转移的时候并不会呆在原地不动,而是会转移到TTA
的第二位上,而TTA
的第二个节点的\(len\)值是\(0\),并不能对\(dp[2][j][0]\)产生贡献,造成答案错误。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e4 + 7;
const int mod = 1e9 + 9;
int get(char c) {
return c == ‘A‘? 0 : c == ‘T‘? 1 : c == ‘C‘? 2 : 3;
}
int n, m;
void add(int &a, int b) {
a += b;
if (a >= mod) a -= mod;
}
struct ACAM {
struct node {
int nx[4];
int fail, val;
void init() {
memset(nx, -1, sizeof(nx));
fail = val = 0;
}
}t[207];
int root, tot, maxLen;
int dp[1007][207][15];
int newnode() {
t[++tot].init();
return tot;
}
void init() {
tot = 0;
maxLen = 0;
root = newnode();
}
void insert(char *s) {
int len = strlen(s);
maxLen = max(maxLen, len);
int now = root;
for (int i = 0; i < len; ++i) {
int v = get(s[i]);
if (t[now].nx[v] == -1) t[now].nx[v] = newnode();
now = t[now].nx[v];
}
t[now].val = len;
}
void getFail() {
for (int i = 0; i < 4; ++i) t[0].nx[i] = 1;
queue<int>q;
q.push(1);
t[1].fail = 0;
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = 0; i < 4; ++i) {
int v = t[u].nx[i];
int Fail = t[u].fail;
if (v == -1) {
t[u].nx[i] = t[Fail].nx[i];
continue;
}
t[v].fail = t[Fail].nx[i];
t[v].val = max(t[v].val, t[t[v].fail].val);
q.push(v);
}
}
}
void DP() {
memset(dp, 0, sizeof(dp));
dp[0][1][0] = 1;
for (int i = 0; i < n; ++i) {
for (int j = 1; j <= tot; ++j) {
for (int k = 0; k < maxLen; ++k) {
for (int x = 0; x < 4; ++x) {
int v = t[j].nx[x];
if (v == -1) continue;
add(dp[i + 1][v][t[v].val > k? 0 : k + 1], dp[i][j][k]);
}
}
}
}
int ans = 0;
for (int j = 1; j <= tot; ++j) {
add(ans, dp[n][j][0]);
}
printf("%d\n", ans);
}
}acam;
char s[15];
void solve() {
scanf("%d %d", &n, &m);
acam.init();
for (int i = 1; i <= m; ++i) {
scanf("%s", s);
acam.insert(s);
}
acam.getFail();
acam.DP();
}
int main() {
solve();
return 0;
}