Description
给定一个长度为$n$的字符串,串中的字符保证是前$k$个小写字母。你可以在字符串后再添加$m$个字符,使得新字符串所包含的不同的子序列数量尽量多。当然,前提是只能添加前$k$个小写字母。求新的长度为$n+m$的串最多的不同子序列数量。答案对$10^9+7$取模。
Input
输入第一行两个数$m,k$。
接下来一行一个字符串,长度为$n$,表示原始的字符串$s$。
Output
一个数,表示答案。
Sample Input
1 3
ac
Sample Output
8
HINT
$n,m\;\leq\;10^6,k\;\leq\;26$
Solution
当$m=0$时,
$lst[i]$表示字符$i$上一次出现的位置,
$f[i]$表示以第$i$位结尾的新出现的不同的子序列的个数.
以第$x(lst[i]\;\leq\;x<i)$位结尾的新出现的子序列末尾加上$s[i]$为一个新的子序列.(反证法可证$x(0<x<lst[i])$不可行)
$f[i]=\sum_{j=lst[s[i]]}^{i-1}f[j]$.
这个可以用前缀和优化.
当$m\not=0$时,
设$sum[i]=\sum_{j=1}^{i}f[j]$,
则$f[i]=sum[i-1]-sum[lst[j]-1](n<i\;\leq\;n+m)$
$f[i]$最大,即$lst[j]-1$最小.
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define K 2000005
#define M 1000000007
using namespace std;
int s[K],sum,tmp,m,n,t;
bool u[K];char c;
int f[K],lst[K];
inline void Aireen(){
scanf("%d%d",&m,&t);
c=getchar();
for(n=1;scanf("%c",&c)==1;++n){
if(!(c>='a'&&c<='z'))
break;
if(lst[c-'a'])
f[n]=(s[n-1]-s[lst[c-'a']-1]+M)%M;
else f[n]=(s[n-1]+1)%M;
s[n]=(s[n-1]+f[n])%M;
lst[c-'a']=n;
}
--n;
if(t) for(int i=n+1,j,k;i<=n+m;++i){
k=lst[0];j=0;
for(int l=1;l<t;++l){
if(lst[l]<k){
k=lst[l];j=l;
}
}
printf("j=%d\n",j);
if(lst[j])
f[i]=(s[i-1]-s[lst[j]-1]+M)%M;
else f[i]=(s[i-1]+1)%M;
s[i]=(s[i-1]+f[i])%M;
lst[j]=i;
}
printf("%d\n",(s[n+m]+1)%M);
}
int main(){
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
Aireen();
fclose(stdin);
fclose(stdout);
return 0;
}
因为卡空间$1MB$,每次转移只与$f[lst[i]-1]$有关,所以只需$O(k)$的空间.
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define K 26
#define M 1000000007
using namespace std;
int s[K],lst[K],sum,tmp,m,n,t;
char c;
inline void Aireen(){
scanf("%d%d",&m,&t);
c=getchar();
for(n=1;scanf("%c",&c)==1;++n){
if(!(c>='a'&&c<='z'))
break;
tmp=s[c-'a'];s[c-'a']=sum;
if(lst[c-'a']) sum=((sum<<1)%M-tmp+M)%M;
else sum=((sum<<1)+1)%M;
lst[c-'a']=n;
}
--n;
if(t) for(int i=1,j,k;i<=m;++i){
k=lst[0];j=0;
for(int l=1;l<t;++l){
if(lst[l]<k){
k=lst[l];j=l;
}
}
tmp=s[j];s[j]=sum;
if(lst[j]) sum=((sum<<1)%M-tmp+M)%M;
else sum=((sum<<1)+1)%M;
lst[j]=i+m;
}
printf("%d\n",(sum+1)%M);
}
int main(){
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
Aireen();
fclose(stdin);
fclose(stdout);
return 0;
}