HDU7055. Yiwen with Sqc
题意:
对于一个字符串\(s\),它的第\(i\)个字符表示为\(s_i\)
定义一个函数\(sqc(s,l,r,c)\),\(s\)代表字符串,\(1\leq l\leq r\leq n\)?,\(c\)?是一个字符,用“ASCII”码表示。这个函数返回\(s\)的子串\(s_l s_{l+1} \dots s_r\)?中\(c\)出现的次数
现要计算\(\sum\limits_{c=97}^{122} \sum\limits_{i=1}^{n} \sum\limits_{j=i}^{n}sqc^2(s,i,j,c)\)
分析:
考虑分别计算26个字母的贡献,不妨考虑字母\(a\)??
假设一共有\(cnt\)个\(a\)
设\(pos_i\)为第\(i\)个\(a\)?出现的位置,特别地,定义\(pos_0=0,pos_{cnt+1}=n+1\)
区间左端点\(l\)??????落在\((pos_i,pos_{i+1}]\)??????,当右端点\(r\)??落在\([pos_j,pos_{j+1})\)??中时,\([l,r]\)?中字母\(a\)?的数量是\(j-i\)????
那么总贡献应为
不妨设\(len_i=pos_{i+1}-pos_{i}\)?,则总贡献化为
假如我们能在\(O(1)\)???的时间内计算出\(sum_j=\sum\limits_{i=0}^{j-1}len_i(j-i)^2\)???,我们就能在\(O(cnt)\)???的时间内求出所有字母\(a\)????的贡献
为了方便,定义\(sum_0=0\),考虑差分\(d1_j=sum_j-sum_{j-1}=len_{j-1}+\sum\limits_{i=0}^{j-2}len_i(2j-2i-1)\)
有\(sum_j=sum_{j-1}+d1_j\),如果我们能\(O(1)\)求出\(d1_j\),那么通过前缀和即可\(O(1)\)算出\(sum_j\)
为了方便,定义\(d1_0=0,d1_1=len_0\)?,考虑再差分
有\(d1_j=d1_{j-1}+d2_j\)??,如果我们能\(O(1)\)??求出\(d2_j\)??,那么通过前缀和即可\(O(1)\)??算出\(d1_j\)??
为了方便,定义\(d2_0=0,d2_1=len_0,d2_2=len_1+2len_0\)?,考虑再差分
有\(d2_j=d2_{j-1}+d3_j\),如果我们能\(O(1)\)求出\(d3_j\),那么通过前缀和即可\(O(1)\)算出\(d2_j\)
为了方便,定义\(d3_0=0,d3_1=len_0,d3_2=len_0+len_1\)?
\(d3_j\)由\(len\)组成,而\(len\)显然可以通过\(pos\)来得到,\(pos\)可以事先预处理,所以我们可以\(O(cnt)\)的时间内求出所有字母\(a\)?的贡献。
将所有字母的贡献求和即为答案,预处理\(pos\)需要\(O(n)\),计算需要\(O(\sum cnt)\)即为\(O(n)\)?
代码:
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
typedef long long Lint;
const Lint mod = 998244353;
const int maxn = 1e5 + 10;
vector<Lint> pos[26];
vector<Lint> len[26];
char str[maxn];
void solve() {
scanf("%s", str + 1);
int n = strlen(str + 1);
for (int i = 1; i <= n; i++) {
if (pos[str[i] - ‘a‘].empty())
pos[str[i] - ‘a‘].push_back(0);
pos[str[i] - ‘a‘].push_back(i);
}
for (int i = 0; i < 26; i++) {
if (!pos[i].empty()) {
pos[i].push_back(n + 1);
len[i].resize(pos[i].size() - 1);
for (int j = 0; j < pos[i].size() - 1; j++) {
len[i][j] = pos[i][j + 1] - pos[i][j];
}
}
}
Lint res = 0;
for (int i = 0; i < 26; i++) {
if (pos[i].empty())
continue;
Lint ans = 0, sum = 0, d1 = 0, d2 = 0, d3;
for (int j = 1; j <= pos[i].size() - 2; j++) {
if (j == 1)
d3 = len[i][j - 1];
else
d3 = (len[i][j - 1] + len[i][j - 2]) % mod;
d2 = (d2 + d3) % mod;
d1 = (d1 + d2) % mod;
sum = (sum + d1) % mod;
ans = (ans + len[i][j] * sum % mod) % mod;
}
res = (res + ans) % mod;
}
for (int i = 0; i < 26; i++) {
pos[i].clear();
}
printf("%lld\n", res);
}
int main() {
int T;
scanf("%d", &T);
while (T--)
solve();
return 0;
}