Desprition
给出一个只由字母 \(A\),\(G\),\(T\),\(C\) 组成的字符串 \(S\) ,长度为 \(n\) ,对于每个 \(i\) \(\in\) \([0,n]\),问有多少个长度为 \(m\),仅含有 \(A\),\(G\),\(T\),\(C\) 的字符串 \(T\) 使得 \(S\) 与 \(T\) 的最长公共子序列长度为 \(i\) 。
Solution
先研究 \(LCS\) 的转移柿子
\[ lcs[i][j] = \begin{cases}lcs[i - 1][j - 1] + 1 (S[i] == T[j])\\ max(lcs[i - 1][j],lcs[i][j - 1])\end{cases} \]得到结论:
当 \(i\) 固定时,
\[f[i][j - 1] \le f[i][j] \le f[i - 1][j] + 1 \]也就是,\(lcs[i][j]\) 和 $lcs[i][j - 1] 最多相差 \(1\) 且满足单调不减。
因此我们可以使用差分,又因为 \(\left|\ S \right| <= 15\),可以直接把差分后的\(lcs\) 状压起来。
定义 \(f[i][j]\) 为当 \(lcs\) 状态为 \(i\) 时加上 字符\(j\)后 的状态, \(dp[i][j]\) 长度为 \(i\) 时,状态为 \(j\) 的方案数。
首先预处理出 \(f[i][j]\) , 具体注释在代码里。
重点在于 \(dp\) 柿子的推导, 其实很简单啊。
你想嘛, \(dp[i][j]\) 的定义长度为 \(i\) 时,状态为 \(j\) 的方案数, 那肯定是从 \(dp[i - 1][k]\) 转移过来的。
那 \(k\) 怎么确定呢???
前面的 \(f[][]\) 不就是用来干这件事的吗?
直接枚举当前长度 \(i\), 长度为 \(i - 1\) 时的状态 \(j\) 以及 第 \(i\) 位的情况 \(k\)
那么当前状态就应该是 \(f[j][k]\) —— 在 \(j\) 后增加 字符\(k\) 的状态。
那么就可以得出式子啦~
\[ dp[i][f[j][k = dp[i][f[j][k]] + dp[i - 1][j]; \]最后答案的计算,因为将 \(lcs\) 差分了,所以 \(lcs\) 的实际长度就是 状压状态下 \(i\) 中 \(1\) 的个数。
另外的小细节在代码中啦~
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define il inline
const int N = (1 << 15) + 5, M = 1e3 + 5, mod = 1e9 + 7;
int T, n, m, num, t, a[20], f[N][5], dp[M][N], g[2][20], ans[20];
char s[N];
il int cnt(int x) {
int res = 0;
while(x) {
res += x & 1;
x >>= 1;
}
return res;
}
il int solve(int x,int y) {
int res = 0;
memset(g,0,sizeof(g));
for(int i = 0; i < n; i ++) g[0][i + 1] = g[0][i] + ((x >> i) & 1);
for(int i = 1; i <= n; i ++) {
if(a[i] == y) g[1][i] = g[0][i - 1] + 1;
g[1][i] = max(max(g[1][i],g[0][i]),g[1][i - 1]);
}
for(int i = 0; i < n; i ++) res += (1 << i) * (g[1][i + 1] - g[1][i]);
return res;
}
il void read(int &x) {
x = 0; char s = getchar();
while(s < '0' || s > '9') s = getchar();
while(s <= '9' && s >= '0') x = x * 10 + s - '0', s = getchar();
}
il void write(int x) {
if(x < 0) x = -x, putchar('-');
if(x > 9) write(x / 10), x %= 10;
putchar(x + '0');
}
int main() {
read(T);
while(T--) {
memset(ans,0,sizeof(ans));
memset(dp,0,sizeof(dp));
scanf("%s",s + 1), read(m);
n = strlen(s + 1), num = 1 << n;
for(int i = 1; i <= n; i ++) {
if(s[i] == 'A') a[i] = 1;
else if(s[i] == 'G') a[i] = 2;
else if(s[i] == 'T') a[i] = 3;
else a[i] = 4;
}
for(int i = 0; i < num; i ++) {
for(int j = 1; j <= 4; j ++) f[i][j] = solve(i,j);
}
dp[0][0] = 1;
for(int i = 1; i <= m; i ++) {
for(int j = 0; j < num; j ++) {
for(int k = 1; k <= 4; k ++) dp[i][f[j][k]] = (dp[i][f[j][k]] + dp[i - 1][j]) % mod;
}
}
for(int i = 0; i < num; i ++) {
t = cnt(i);
ans[t] = (ans[t] + dp[m][i]) % mod;
}
for(int i = 0; i <= n; i ++) write(ans[i]), putchar('\n');
}
return 0;
}