题目大意
试求出一个字符串每一个长度为偶数的后缀在原字符串中出现的次数。
解题思路
比较简单。
对这个字符串建 AC
自动机,然后建上 fail
树。
那么一个长度为偶数的前缀在原字符串中出现的次数就是这个前缀在 Trie
上的结束节点在 fail
上的子树和。
也可以优化,意义是一样的。
AC CODE
考场代码。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define _ 500000
int ans;
char str[_];
int cnt, tr[_][27], tag[_], fail[_];
void insert(char *s)
{
int p = 0;
int len = strlen(s + 1);
for(int i = 1; i <= len; ++i)
{
int v = s[i] - 'a';
// cout << v << endl;
if(!tr[p][v]) tr[p][v] = ++cnt;
p = tr[p][v];
if(i % 2 == 0) tag[p] = 1;
}
}
void getfail()
{
queue<int> q;
for(int i = 0; i < 26; ++i)
{
if(tr[0][i])
{
fail[tr[0][i]] = 0;
q.push(tr[0][i]);
}
}
while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(tr[u][i])
{
fail[tr[u][i]] = tr[fail[u]][i];
q.push(tr[u][i]);
}
else
{
tr[u][i] = tr[fail[u]][i];
}
}
tag[u] += tag[fail[u]];
}
}
signed main()
{
scanf("%s", str + 1);
insert(str);
getfail();
for(int i = 0; i <= cnt; ++i) ans += tag[i];
printf("%lld\n", ans);
return 0;
}
便于理解的代码。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define _ 500000
int ans;
char str[_];
int cnt, tr[_][27], tag[_], fail[_];
void insert(char *s)
{
int p = 0;
int len = strlen(s + 1);
for(int i = 1; i <= len; ++i)
{
int v = s[i] - 'a';
// cout << v << endl;
if(!tr[p][v]) tr[p][v] = ++cnt;
p = tr[p][v];
if(i % 2 == 0) tag[p] = 1;
}
}
int tot, head[_], to[_ << 1], nxt[_ << 1];
void add(int u, int v)
{
to[++tot] = v;
nxt[tot] = head[u];
head[u] = tot;
}
void getfail()
{
queue<int> q;
for(int i = 0; i < 26; ++i)
{
if(tr[0][i])
{
fail[tr[0][i]] = 0;
q.push(tr[0][i]);
}
}
while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(tr[u][i])
{
fail[tr[u][i]] = tr[fail[u]][i];
q.push(tr[u][i]);
}
else
{
tr[u][i] = tr[fail[u]][i];
}
}
// tag[u] += tag[fail[u]];
}
for(int i = 1; i <= cnt; ++i)
add(fail[i], i);
}
int siz[_];
void query(char *s)
{
int p = 0;
int len = strlen(s + 1);
for(int i = 1; i <= len; ++i)
{
int v = s[i] - 'a';
p = tr[p][v];
if(i % 2 == 0)
{
ans += siz[p];
}
}
}
void dfs(int u, int fa)
{
siz[u] = 1;
for(int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if(v == fa) continue;
dfs(v, u);
siz[u] += siz[v];
}
}
signed main()
{
scanf("%s", str + 1);
insert(str);
getfail();
dfs(0, -1);
query(str);
printf("%lld\n", ans);
return 0;
}