这题想了好一会呢。。刚开始想错了,以为用自动机预处理出k长度可以包含的合法的数的个数,然后再数位dp一下就行了,写到一半发现不对,还要处理当前走的时候是不是为合法的,这一点无法移到trie树上去判断。
之后想到应该在trie树上进行数位dp,走到第i个节点且长度为j的状态是确定的,所以可以根据trie树上的节点来进行确定状态。
dp[i][j]表示当前节点为i,数第j位时可以包含多少个合法的数。
#include <iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<stdlib.h>
#include<vector>
#include<cmath>
#include<queue>
#include<set>
using namespace std;
#define N 2010
#define LL long long
#define INF 0xfffffff
const double eps = 1e-;
const double pi = acos(-1.0);
const double inf = ~0u>>;
const int child_num = ;
const int mod = ;
int dp[][N];
char s1[],s2[];
class AC
{
private:
int ch[N][child_num];
int Q[N];
int fail[N];
int val[N];
int id[];
int sz;
int dd[][N];
public:
void init()
{
fail[] = ;
id[''] = ;id[''] = ;
}
void reset()
{
memset(val,,sizeof(val));
memset(ch[],,sizeof(ch[]));
sz=;
}
void insert(char *a,int key)
{
int p = ;
for( ; *a ; a++)
{
int d = id[*a];
if(ch[p][d]==){
memset(ch[sz],,sizeof(ch[sz]));
ch[p][d] = sz++;
}
p = ch[p][d];
}
val[p] = key;
}
void construct()
{
int i,head=,tail = ;
for(i = ;i < child_num ; i++)
{
if(ch[][i])
{
fail[ch[][i]] = ;
Q[tail++] = ch[][i];
}
}
while(head!=tail)
{
int u = Q[head++];
val[u]|=val[fail[u]];
for(i = ;i < child_num ; i++)
{
if(ch[u][i])
{
fail[ch[u][i]] = ch[fail[u]][i];
Q[tail++] = ch[u][i];
}
else ch[u][i] = ch[fail[u]][i];
}
}
}
int dfs(char *s,int i,int c,int e,int k)
{
if(i==-)
{
return ;
}
if(!e&&~dp[i][c])
{
return dp[i][c];
}
int mk = e?s[i]-'':;
int ans = ;
for(int j = ; j <= mk ; j++)
{
if(!k&&j==&&i)
{
ans = (ans+dfs(s,i-,c,e&&j==mk,k));
continue;
}
int p = c,flag = ;
for(int g = ; g >= ; g--)
{
int o = (j&(<<g))?:;
p = ch[p][o];
int tmp = p;
while(tmp!=)
{
if(val[tmp])
{
flag = ;
break;
}
tmp = fail[tmp];
}
if(!flag) break;
}
if(flag)
{
ans = (ans+dfs(s,i-,p,e&&j==mk,))%mod;
}
}
return e?ans:dp[i][c] = ans;
}
void work(char *s1,char *s2)
{
memset(dp,-,sizeof(dp));
printf("%d\n",(dfs(s2,strlen(s2)-,,,)-dfs(s1,strlen(s1)-,,,)+mod)%mod);
}
}ac;
char vir[];
char ss1[],ss2[];
int main()
{
int t,n,i;
ac.init();
scanf("%d",&t);
while(t--)
{
ac.reset();
scanf("%d",&n);
while(n--)
{
scanf("%s",vir);
ac.insert(vir,);
}
ac.construct();
scanf("%s%s",s1,s2);
int k = strlen(s1),kk= strlen(s2);
for(i = k- ; i >= ; i--)
{
if(s1[i]>'')
{
s1[i]-=;
break;
}
else
s1[i] = '';
}
for(i = ; i < k ; i++)
ss1[k--i] = s1[i];
ss1[k] = '\0';
for(i = ; i < kk ; i++)
ss2[kk--i] = s2[i];
ss2[kk] = '\0';
ac.work(ss1,ss2);
}
return ;
}