给你 \(n\) 个字符串,对每个字符串,你可以删除其任意一个字符或让其保持原样,求最后使得字符串字典序不降得方案数。
不难想到一个 \(\mathcal{O}(n(\sum |S|)^2)\) ,我们定义状态 \(f_{i,j}\) 表示前 \(i\) 个字符串,结尾的字符串是第 \(i\) 个字符串删除第 \(j\) 个字符的方案数,如果不删除则 \(j = n + 1\) 。
手算一下不难发现 \(f_{i,j}\) 是 \(f_{i-1}\) 中删除 \(k\) 后比当前串删除 \(j\) 后小的位置的 \(f_{i-1,k}\) 之和。如果对于 \(j,k\) 按照删除对应位置后字符串的字典序排序,那么转移一定是一段前缀和。
对这些位置排序可以做到线性,具体做法见这道题。
排序后,如果我们用双指针扫一遍,我们一共需要比较 \(\mathcal{O}(\sum |S|)\) 次 \(S_{i,j}\) 和 \(S_{i-1,k}\),如果直接比较时间复杂度是 \(\mathcal{O}((\sum |S|)^2)\) 。
我们可以用哈希加速这个过程,二分出两个串第一个不同的位置即可。这样时间复杂度是 \(\mathcal{O}(\sum|S|\log )\) 的。
细节非常多,这里就不展开讨论。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define pre(i,a,b) for(int i=a;i>=b;i--)
const int N = 1000005, P = 1000000007, Q = 777777773, bas = 233;int pw[N];
void calc(char *s,int *u,int n){
int cur = 1, L = 1, R = n;
rep(i, 1, n)
if(i == n){
pre(j, n, R + 1)u[j + 1] = u[j];
while(cur)cur--, u[L++] = i - cur;
u[L] = n + 1;
}
else if(s[i] == s[i + 1])cur++;
else {
if(s[i] < s[i + 1]){
while(cur)cur--, u[R--] = i - cur;
}
else{
while(cur)cur--, u[L++] = i - cur;
}
cur = 1;
}
}
void init(char *s,int *u,int n){
u[0] = 0;
rep(i, 1, n)u[i] = (1LL * u[i - 1] * bas + s[i]) % Q;
}
int get(int *u,int x,int ban){
if(x < ban)return u[x];
int res = x - ban + 1;
return (1LL * u[ban - 1] * pw[res] % Q + u[x + 1] - 1LL * u[ban] * pw[res] % Q + Q)% Q;
}
bool cmp(char *s,char *t,int *u,int *v,int x,int y,int lenu,int lenv){
int ed = ~0,l = 1, r = std::min(lenu - (x <= lenu), lenv - (y <= lenv));
while(l <= r){
int mid = (l + r) >> 1;
if(get(u, mid, x) == get(v, mid, y))l = mid + 1;
else r = mid - 1, ed = mid;
}
if(~ed)return s[ed + (ed >= x)] < t[ed + (ed >= y)];
else return lenu - (x <= lenu) <= lenv - (y <= lenv);
}
int n,f[N],u[N],v[N],g[N],hs1[N],hs2[N];char s[N],t[N];
int main(){
pw[0] = 1;rep(i, 1, N - 5)pw[i] = 1LL * pw[i - 1] * bas % Q;
scanf("%d",&n);
scanf("%s",s + 1);
int w = strlen(s + 1);
calc(s, u, w);init(s, hs1, w);
rep(i, 1, w + 1)f[i] = 1;
rep(op, 2, n){
scanf("%s", t + 1);
int m = strlen(t + 1), j = 1, sum = 0;
calc(t, v, m);init(t, hs2, m);
rep(i, 1, m + 1)g[i] = 0;
rep(i, 1, m + 1){
while(j <= w + 1 && cmp(s, t, hs1, hs2, u[j], v[i], w, m))sum = (sum + f[j++]) % P;
g[i] = sum;
}
w = m;rep(i, 1, m + 1)f[i] = g[i], s[i] = t[i], u[i] = v[i], hs1[i] = hs2[i];
}
int ans = 0;
rep(i, 1, w + 1)ans = (ans + f[i]) % P;
printf("%d\n",ans);
return 0;
}