链接:https://www.nowcoder.com/acm/contest/71/B
来源:牛客网
题目描述
设s,t为两个字符串,定义f(s,t) = t的子串中,与s相等的串的个数。如f("ac","acacac")=3, f("bab","babab")=2。现在给出n个字符串,第i个字符串为si。你需要对,求出,由于答案很大,你只需要输出对 998244353取模后的结果。
输入描述:
第一行一个整数n。
接下来n行每行一个仅由英文字母构成的非空字符串,第i个字符串代表s
i
。
输出描述:
共n行,第i行输出
对 998244353取模的结果。
输入例子:
1
BALDRSKYKirishimaRain
输出例子:
1
-->
示例1
输入
1
BALDRSKYKirishimaRain
输出
1
备注:
1 ≤ n ≤ 10^6,所有字符串的总长度不超过2*10^6
题解
$kmp$。
先观察一下要求的那个东西,注意是乘积,有一个为$0$即为$0$。
也就是说,只有长度最短的那些可能不为$0$,其余的一定为$0$。
长度最短的那些不全等,那么也是$0$,否则和别的都计算一遍即可。
#include <bits/stdc++.h>
using namespace std; const long long mod = 998244353LL;
const long long modh = 1e9 + 7;
const long long base = 131LL;
const int maxn = 1e6 + 10;
string s[maxn];
int len[maxn];
long long ans[maxn];
long long h[maxn];
int n; char S[2 * maxn], T[2 * maxn];
int slen, tlen;
int nx[2 * maxn]; void getNext()
{
int j, k;
j = 0; k = -1; nx[0] = -1;
while(j < tlen)
if(k == -1 || T[j] == T[k])
nx[++j] = ++k;
else
k = nx[k];
} int kmp(string &a) {
slen = a.length();
for(int i = 0; i < slen; i ++) {
S[i] = a[i];
S[i + 1] = 0;
}
int aa = 0;
int i, j = 0;
if(slen == 1 && tlen == 1)
{
if(S[0] == T[0])
return 1;
else
return 0;
}
for(i = 0; i < slen; i++)
{
while(j > 0 && S[i] != T[j])
j = nx[j];
if(S[i] == T[j])
j++;
if(j == tlen)
{
aa++;
j = nx[j];
}
}
return aa;
} int main() {
scanf("%d", &n);
int mn_len = 2 * maxn;
for(int i = 1; i <= n; i ++) {
cin >> s[i];
len[i] = s[i].length();
for(int j = 0; j < len[i]; j ++) {
h[i] = h[i] * base % mod;
h[i] = (h[i] + s[i][j]) % mod;
}
mn_len = min(mn_len, len[i]);
}
int idx;
for(int i = 1; i <= n; i ++) {
if(len[i] == mn_len) idx = i;
}
int fail = 0;
for(int i = 1; i <= n; i ++) {
if(len[i] == mn_len) {
if(h[i] != h[idx]) fail = 1;
}
} if(fail == 0) {
for(int i = 0; i < len[idx]; i ++) {
T[i] = s[idx][i];
T[i + 1] = 0;
}
tlen = len[idx];
getNext();
long long A = 1;
for(int i = 1; i <= n; i ++) {
if(len[i] == mn_len) continue;
A = A * kmp(s[i]) % mod;
if(A == 0) break;
}
for(int i = 1; i <= n; i ++) {
if(len[i] == mn_len) ans[i] = A;
}
} for(int i = 1; i <= n; i ++) {
printf("%lld\n", ans[i]);
} return 0;
}