H 洪尼玛的保险箱
题意:
求两个字符串之间相同子串的出现对数:统计所有二元组(某子串在第一个串中的一次出现, 同一子串在第二个串中的一次出现),要求前者在第一个串中以奇数位置结尾,后者在第二个串中以偶数位置结尾(位置从 1 开始编号)
思路:
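对两个串建广义后缀自动机,插入时只把满足奇偶限制的结尾位置计入对应串的 endpos;再按 \(\mathrm{maxlen}\) 从大到小沿后缀链接累加,得到每个状态 \(v\) 的 \(|endpos_1(v)|\) 与 \(|endpos_2(v)|\),答案即为 \(\sum_v |endpos_1(v)|\cdot|endpos_2(v)|\cdot(\mathrm{maxlen}(v)-\mathrm{maxlen}(\mathrm{link}(v)))\)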
#include<cstring>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<vector>
using namespace std;
typedef long long ll;
// Capacity. Each input string fits in a 100010-byte buffer read at offset 1,
// so n + m <= 2 * (100010 - 2). Every insert() creates at most two states
// (the new state plus one clone), hence Size <= 1 + 2 * (n + m) < Maxn.
const int Maxn = 4 * 100010 + 10;
// Number of strings whose end positions are tracked separately in sz[][].
const int N = 2;

/*
 * Generalized suffix automaton over a 26-letter alphabet (ch in [0, 26)).
 * State 1 is the root. State 0 is the null state: it is the root's link and
 * the terminator of the suffix-link walks in insert().
 *
 * sz[v][k] counts the end positions in string k of the substrings of state v,
 * restricted to the positions marked during insert(). After insertion it
 * holds 0/1 marks. After rsort() it holds the subtree sums, i.e. the
 * restricted |endpos_k(v)|.
 */
struct Suffix_Automata{
    static const int Sigma = 26;
    int maxlen[Maxn], trans[Maxn][Sigma], link[Maxn], Size;
    int sz[Maxn][N];

    // Reset state x to an empty state: no transitions, no link, no marks.
    void clr(int x){
        maxlen[x] = link[x] = 0;
        memset(trans[x], 0, sizeof(trans[x]));
        memset(sz[x], 0, sizeof(sz[x]));
    }

    // Start a fresh automaton that contains only the root (state 1).
    void init(){
        Size = 1;
        clr(1);
    }

    /*
     * Clone x into a new state y with maxlen[y] = maxlen[p] + 1 (y starts
     * unmarked). Then redirect the ch-transitions of p and its suffix-link
     * ancestors from x to y, and make y the suffix link of x.
     * Returns y. Shared by both branches of insert().
     */
    int split(int p, int x, int ch){
        int y = ++Size;
        clr(y);
        maxlen[y] = maxlen[p] + 1;
        memcpy(trans[y], trans[x], sizeof(trans[y]));
        while (p && trans[p][ch] == x)
            trans[p][ch] = y, p = link[p];
        link[y] = link[x];
        link[x] = y;
        return y;
    }

    /*
     * Extend the prefix ending in state `last` by character ch.
     *   ch   : character code in [0, Sigma).
     *   last : state of the previous prefix; use 1 (the root) for the first
     *          character of each string.
     *   id   : index (0 .. N-1) of the string being inserted.
     *   idx  : 0/1 tag of the current position. The position is recorded in
     *          sz[.][id] only when id ^ idx is nonzero. With main() passing
     *          idx = i & 1, that marks odd positions of string 0 and even
     *          positions of string 1.
     * Returns the state of the extended prefix, to pass back as `last`.
     */
    int insert(int ch, int last, int id, int idx){
        const bool counted = (id ^ idx) != 0;
        if (trans[last][ch]){
            // The extended prefix already occurs in an earlier string.
            // Reuse its state, or split it so that maxlen matches exactly.
            int x = trans[last][ch];
            if (maxlen[last] + 1 != maxlen[x])
                x = split(last, x, ch);
            if (counted)
                sz[x][id] = 1;
            return x;
        }
        int z = ++Size;
        clr(z);
        if (counted)
            sz[z][id] = 1;
        maxlen[z] = maxlen[last] + 1;
        int p = last;
        while (p && !trans[p][ch])
            trans[p][ch] = z, p = link[p];
        if (!p){
            link[z] = 1;
        }else{
            int x = trans[p][ch];
            link[z] = (maxlen[p] + 1 == maxlen[x]) ? x : split(p, x, ch);
        }
        return z;
    }

    int c[Maxn], a[Maxn];
    /*
     * Counting-sort the states by maxlen into a[1..Size] (ascending). Then,
     * deepest first, add each state's sz[] into its suffix-link parent.
     * n must be >= the largest maxlen, i.e. the longest inserted string.
     */
    void rsort(int n){
        for (int i = 0; i <= n; ++i) c[i] = 0;
        for (int i = 1; i <= Size; ++i) c[maxlen[i]]++;
        for (int i = 1; i <= n; ++i) c[i] += c[i - 1];
        for (int i = 1; i <= Size; ++i) a[c[maxlen[i]]--] = i;
        for (int i = Size; i >= 1; --i){
            const int v = a[i], parent = link[v];
            // The root's link is the null state 0. Skip it so that nothing
            // accumulates in state 0, which clr() never resets; otherwise its
            // counts would grow across test cases and could overflow.
            if (!parent) continue;
            for (int k = 0; k < N; ++k)
                sz[parent][k] += sz[v][k];
        }
    }

    /*
     * Print, and return, the sum over non-root states v of
     *   sz[v][0] * sz[v][1] * (maxlen[v] - maxlen[link[v]]).
     * This is the number of (occurrence in string 0, occurrence in string 1)
     * pairs of equal substrings that end at marked positions.
     * Call after rsort(). The root represents no substring and is skipped.
     */
    ll solve(){
        ll ans = 0;
        for (int i = 2; i <= Size; ++i)
            ans += 1ll * sz[i][0] * sz[i][1] * (maxlen[i] - maxlen[link[i]]);
        printf("%lld\n", ans);
        return ans;
    }
}sam;
char s1[100010], s2[100010];
/*
 * Read test cases until input runs out. Each test case is two strings, stored
 * 1-indexed in s1 and s2; characters are mapped with c - 'a'. For each test
 * case the automaton is rebuilt, and sam.solve() prints the answer.
 */
int main(){
    // Each buffer is filled from offset 1 and needs a terminator, so at most
    // 100010 - 2 = 100008 characters fit. The field width prevents overflow.
    // Require both strings (== 2). With ~scanf, a trailing lone string
    // (return value 1) would be paired with the stale s2 of the previous case.
    while (scanf("%100008s%100008s", s1 + 1, s2 + 1) == 2){
        int n = strlen(s1 + 1), m = strlen(s2 + 1);
        sam.init();
        int last = 1;
        for (int i = 1; i <= n; ++i) last = sam.insert(s1[i] - 'a', last, 0, i & 1);
        last = 1;
        for (int i = 1; i <= m; ++i) last = sam.insert(s2[i] - 'a', last, 1, i & 1);
        sam.rsort(max(n, m));
        sam.solve();
    }
    return 0;
}