题目大意:
给你一些密码片段字符串,让你求长度为n,且至少包含k个不同密码片段串的字符串的数量。
题解:
因为密码串不多,可以考虑状态压缩
设dp[i][j][sta]表示长为i的字符串匹配到j节点且状态为sta的数量。
其中sta存储的是包含的密码串情况,在构建fail指针时,当前节点要并上fail指针所指的节点。
跑ac自动机,儿子节点从父亲节点转移。
最后取dp[len][...][sta]的和,其中sta满足二进制中1的数量>=k,
这一点可以像树状数组的lowbit那样快速求出:
inline int count(int x){
int ret = 0;
while(x){
ret++;
x -= (x & -x);
}
return ret;
}
code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<algorithm>
#include<queue>
using namespace std;
const int N = 20, L = 20, Mod = 20090717;
int n, m, k, tot;
long long dp[30][110][1100], ans;
char s[20];
queue<int> que;
struct node{
int trans[27];
int fail, no;
int state;
inline void clear(){
memset(trans, 0, sizeof trans);
fail = state = no = 0;
}
}trie[1010];
inline int getVal(char st){
return st - 'a' + 1;
}
inline void insert(int num){
int len = strlen(s + 1), pos = 1;
for(int i = 1; i <= len; i++){
int val = getVal(s[i]);
if(!trie[pos].trans[val])
trie[trie[pos].trans[val] = ++tot].clear();
pos = trie[pos].trans[val];
}
trie[pos].state |= 1 << num;
}
inline void buildFail(){
for(int i = 1; i <= 26; i++) trie[0].trans[i] = 1;
que.push(1);
while(!que.empty()){
int u = que.front(); que.pop();
for(int i = 1; i <= 26; i++){
int v = trie[u].fail;
while(!trie[v].trans[i]) v = trie[v].fail;
int w = trie[u].trans[i];
v = trie[v].trans[i];
if(w){
trie[w].fail = v;
que.push(w);
trie[w].state |= trie[v].state;
}
else trie[u].trans[i] = v;
}
}
}
inline int count(int x){
int ret = 0;
while(x){
ret++;
x -= (x & -x);
}
return ret;
}
inline void solve(){
memset(dp, 0, sizeof dp);
int limit = 1 << m;
dp[0][1][0] = 1;
for(int i = 1; i <= n; i++)
for(int j = 1; j <= tot; j++)
for(int sta = 0; sta < limit; sta++)
if(dp[i - 1][j][sta])
for(int l = 1; l <= 26; l++){
int u = trie[j].trans[l];
dp[i][u][sta | trie[u].state] = (dp[i][u][sta | trie[u].state] + dp[i - 1][j][sta]) % Mod;
}
for(int i = 1; i <= tot; i++)
for(int sta = 0; sta < limit; sta++){
if(count(sta) >= k)
ans = (ans + dp[n][i][sta]) % Mod;
}
}
int main(){
while(scanf("%d%d%d", &n, &m, &k), n + m + k){
trie[tot = 1].clear(); ans = 0;
for(int i = 1; i <= m; i++){
scanf("%s", s + 1);
insert(i - 1);
}
buildFail();
solve();
cout << ans << endl;
}
}