【牛客网】数一数
题目大意
- 定义 \(\operatorname{f}(s,t)\) 为 \(t\) 的子串中,与 \(s\) 相等的串的个数。
- 给出 \(n\) 个字符串,第 \(i\) 个字符串为 \(s_i\),对 \(\forall1\leq i\leq n\),求出 \(\prod_{j=1}^n\operatorname{f}(s_i,s_j)\)。
- 答案对 \(998244353\) 取模。
- \(1\leq n\leq10^6\),所有字符串的总长度不超过 \(2\times10^6\)。
题解
首先,我们可以发现只要当前字符串和其他一个字符串求出的答案为 \(0\),那么这个字符串的答案就为 \(0\)。
然后,我们发现一个长度较大的字符串和一个长度较小的字符串求出的答案肯定是 \(0\)。
所以我们可以先按字符串长度来排序,找到一个(或多个)长度最小的字符串,然后跑 \(n\) 遍 KMP,剩下不是最短的字符串答案均为 \(0\),最后输出答案。
那么如何处理有多个最短的字符串的情况呢?
我们可以比较所有最短的的字符串,如果其中有不一样的就说明答案均为零,否则答案都是一样的,只要跑 \(n\) 遍 KMP。
代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 2000005, mod = 998244353;
vector<int> nxt;
struct str
{
string s;
int len, num;
bool b;
} a[1000005];
int ans = 1, sum = 1, n, res[1000005];
void getnxt(string &substr)
{
nxt.clear();
nxt.resize(substr.size());
int j = -1;
nxt[0] = -1;
for (int i = 1; i < substr.size(); ++i)
{
j = nxt[i - 1];
while (j > -1 && substr[j + 1] != substr[i])
j = nxt[j];
if (substr[j + 1] == substr[i])
++j;
nxt[i] = j;
}
}
int kmp(string &str, string &substr)
{
int cnt = 0;
int j = -1;
for (int i = 0; i < str.length(); ++i)
{
while (j > -1 && substr[j + 1] != str[i])
j = nxt[j];
if (substr[j + 1] == str[i])
++j;
if (j == substr.length() - 1)
{
++cnt;
j = nxt[j];
}
}
return cnt;
}
bool compare(string s1, string s2)
{
int n1 = s1.length();
for (int i = 0; i < n1; i++)
{
if (s1[i] != s2[i])
{
return false;
}
}
return true;
}
bool cmp(str a, str b)
{
return a.len < b.len;
}
void print()
{
for (int i = 1; i <= n; i++)
{
cout << 0 << endl;
}
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++)
{
cin >> a[i].s;
a[i].len = a[i].s.length();
a[i].num = i;
}
if (n == 1)
{
cout << 1;
return 0;
}
sort(a + 1, a + n + 1, cmp);
for (int i = 1; i < n; i++)
{
if (a[i].len == a[i + 1].len)
{
if (!compare(a[i].s, a[i + 1].s))
{
print();
return 0;
}
else
{
sum++;
}
}
else
{
break;
}
}
getnxt(a[1].s);
for (int i = 2; i <= n; i++)
{
ans = 1LL * ans * kmp(a[i].s, a[1].s) % mod;
}
for (int i = 1; i <= sum; i++)
{
a[i].b = 1;
}
for (int i = 1; i <= n; ++i)
{
if (a[i].len == a[1].len)
res[a[i].num] = ans;
else
res[a[i].num] = 0;
}
for (int i = 1; i <= n; i++)
{
printf("%d\n", res[i]);
}
return 0;
}