前期#6
A
思路:直接做数位 DP。逐位枚举数字时,用 KMP 自动机维护当前已与模式串匹配的前缀长度,并顺带记录模式串已完整出现的次数(完整匹配后沿 next 回退,因此重叠出现也会被计入);是否贴着上界等分类讨论见代码。
前期#6 A
#include<bits/stdc++.h>
// Wide integer type used when summing the final answer (solve() uses it).
using ll = long long;
// Capacity of every per-position array; strings are stored 1-based in char[N].
constexpr int N = 305;
// The answer is reported modulo this prime.
constexpr int mod = 998244353;

// Digit-DP table, rolled over the parity of the current position i:
//   f[i & 1][j][k][q] = number of digit prefixes of length i where
//     j = KMP state: length of the longest prefix of s that is a suffix of the
//         digits chosen so far (always < len, see solve()),
//     k = number of complete occurrences of s seen so far,
//     q = 1 if the prefix still equals the prefix of lim ("tight"), else 0.
int f[2][N][N][2];
char s[N];     // pattern, 1-based; kmp() turns its ASCII digits into values 0..9
char lim[N];   // upper bound, 1-based; solve() turns its ASCII digits into values 0..9
int nex[N];    // KMP failure function: nex[i] = longest proper border of s[1..i]
int to[N][10]; // to[j][c] = KMP state after appending digit c in state j, j in [0, len]
int m;         // an n-digit string is counted when it has at least m occurrences
int len;       // length of s
/// Converts the pattern s[1..len] from ASCII digits to digit values 0..9 and
/// builds the KMP failure function nex[] and the digit automaton
/// to[0..len][0..9].
///
/// Row len (the "just matched everything" state) behaves like its border
/// state nex[len], so every entry of to[][] stays within [0, len].
inline void kmp(){
    len = static_cast<int>(std::strlen(s + 1));
    for(int i = 1;i <= len;++i)
        s[i] = static_cast<char>(s[i] - '0');

    // Prefix function: nex[i] = longest proper border of s[1..i].
    nex[1] = 0;
    for(int i = 2, j = 0;i <= len;++i){
        while(j && s[j + 1] != s[i])
            j = nex[j];
        if(s[j + 1] == s[i])
            ++j;
        nex[i] = j;
    }

    // Automaton: extend the match when the next pattern digit is c; otherwise
    // reuse the row of the border state nex[i] (< i, so already computed).
    // State len has no next pattern digit. The previous version compared
    // against the NUL terminator s[len + 1], which equals digit 0, and
    // produced the invalid state len + 1.
    for(int i = 0;i <= len;++i)
        for(int c = 0;c < 10;++c){
            if(i < len && s[i + 1] == c)
                to[i][c] = i + 1;
            else
                to[i][c] = i ? to[nex[i]][c] : 0;
        }
}
inline void solve(){
f[0][0][0][1] = 1;
int n = strlen(lim + 1);
for(int i = 1;i <= n;++i)
lim[i] = lim[i] - '0';
for(int i = 1;i <= n;++i){
int now = i & 1;
int las = (i - 1) & 1;
// std::cout<<i<<"\n";
for(int j = 0;j <= n;++j)
for(int k = 0;k <= n;++k)
for(int q = 0;q <= 1;++q)
f[now][j][k][q] = 0;
for(int j = 0;j <= n;++j)
for(int k = 0;k <= n;++k)
for(int q = 0;q <= 1;++q){
if(f[las][j][k][q]){
// std::cout<<"USE"<<" "<<i - 1<<" "<<j<<" "<<k<<" "<<q<<" "<<f[las][j][k][q]<<"\n";
int tj ;
int tk ;
int tq ;
if(q){
for(int c = 0;c < lim[i];++c){
tq = 0;
int t = to[j][c];
if(t == len){
tk = k + 1;
tj = nex[len];
}else{
tk = k;
tj = t;
}
f[now][tj][tk][tq] = (f[now][tj][tk][tq] + f[las][j][k][q]) % mod;
// std::cout<<"CHOOSE "<<c<<" "<<tj<<" "<<tk<<" "<<tq<<"\n";
}
int t = to[j][lim[i]];
tq = 1;
if(t == len){
tk = k + 1 ;
tj = nex[len];
}else{
tk = k;
tj = t;
}
// std::cout<<"CHOOSE "<<(int)lim[i]<<" "<<tj<<" "<<tk<<" "<<tq<<"\n";
f[now][tj][tk][tq] = (f[now][tj][tk][tq] + f[las][j][k][q]) % mod;
}else{
for(int c = 0;c <= 9;++c){
tq = 0;
int t = to[j][c];
if(t == len){
tk = k + 1;
tj = nex[len];
}else{
tk = k;
tj = t;
}
f[now][tj][tk][tq] = (f[now][tj][tk][tq] + f[las][j][k][q]) % mod;
// std::cout<<"CHOOSE "<<c<<" "<<tj<<" "<<tk<<" "<<tq<<"\n";
}
}
}
}
}
ll ans = 0;
for(int j = 0;j <= n;++j)
for(int k = 0;k <= n;++k)
for(int q = 0;q <= 1;++q){
if(k >= m){
ans = (ans + f[n & 1][j][k][q]) % mod;
}
}
std::cout<<ans<<"\n";
}
int main(){
// freopen("p.out","w",stdout);
scanf("%d",&m);
scanf("%s",lim + 1);
scanf("%s",s + 1);
kmp();
solve();
}