P3181 [HAOI2016]找相同字符(SAM的应用)

传送门

ps:实际上用后缀数组会比较简单…

对串 A A A建立 S A M SAM SAM,拿串 B B B放在串 A A A上跑

当我们加入 b [ i ] b[i] b[i]时,需要计算 [ 1 , i ] [1,i] [1,i]的所有后缀与串 A A A的贡献

设当前在 S A M SAM SAM跑到节点 j j j,匹配后缀长度是 l l l

Ⅰ.现在考虑节点 j j j的贡献

那么节点 j j j中的所有子串中,长度小于等于 l l l的都有贡献

其中有 l − l e n [ f a j ] l-len[fa_j] l−len[faj​]个子串是串 B B B的 [ 1 , i ] [1,i] [1,i]后缀,那么贡献是

s i z [ j ] ∗ ( l − l e n [ f a j ] ) siz[j]*(l-len[fa_j]) siz[j]∗(l−len[faj​])

因为节点 j j j包括的每个串都出现了 s i z [ j ] siz[j] siz[j]次

Ⅱ.考虑节点 j j j的祖先

节点 j j j的祖先都是节点 j j j的后缀,所以节点 j j j的祖先一定都有贡献,每个点的贡献是

s i z [ k ] ∗ ( l e n [ k ] − l e n [ f a k ] ) siz[k]*(len[k]-len[fa_k]) siz[k]∗(len[k]−len[fak​])

理解一下,就是节点 k k k包含 l e n [ k ] − l e n [ f a k ] len[k]-len[fa_k] len[k]−len[fak​]个串,这些串都在 B B B中,而且每个都在 A A A中出现 s i z [ k ] siz[k] siz[k]次

这部分我们预处理一下,提前把父亲的贡献累加到儿子上去,就可以 O ( 1 ) O(1) O(1)计算

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 4e5+10;
struct SAM
{
	int zi[26],len,fa;
	SAM(){ memset( zi,0,sizeof zi ); }
}sam[maxn]; int las = 1,id = 1, siz[maxn];
char s[maxn];
void insert(int c)
{
	int p = las, np = ++id; las = id;
	sam[np].len = sam[p].len+1; siz[np] = 1;
	for( ;p&&!sam[p].zi[c];p=sam[p].fa)	sam[p].zi[c] = np;
	if( p==0 )	{ sam[np].fa = 1; return; }
	int q = sam[p].zi[c];
	if( sam[q].len==sam[p].len+1 )	sam[np].fa = q;//连接一下,因为q的endpos包含n 
	else
	{
		int nq = ++id;
		sam[nq] = sam[q], sam[nq].len = sam[p].len+1;
		sam[q].fa = sam[np].fa = nq;
		for( ;p&&sam[p].zi[c]==q;p=sam[p].fa )	sam[p].zi[c] = nq;
	} 
}
int c[maxn],rk[maxn],f[maxn];
void tuopu()
{
	for(int i=1;i<=id;i++)	c[sam[i].len]++;
	for(int i=1;i<=id;i++)	c[i] += c[i-1];
	for(int i=1;i<=id;i++)	rk[c[sam[i].len]--] = i;
	for(int i=id;i>=1;i--)	siz[sam[rk[i]].fa] += siz[rk[i]];
	for(int i=2;i<=id;i++)
	{
		int u = rk[i], fa = sam[u].fa;
		f[u] = ( sam[u].len-sam[fa].len )*siz[u]+f[fa];
	}
}
signed main()
{
	scanf("%s",s); int n = strlen( s );
	for(int i=0;i<n;i++)	insert( s[i]-'a' );
	tuopu();
	scanf("%s",s); n = strlen( s );
	int ans = 0, p = 1,l = 0;
	for(int i=0;i<n;i++)
	{
		int u = s[i]-'a';
		while( p&&!sam[p].zi[u] )	p = sam[p].fa;
		if( !p )	p = 1,l = 0;
		else//找到了 
		{
			l = min( l,sam[p].len )+1;
			p = sam[p].zi[u]; int fa = sam[p].fa;
			ans += f[fa]+(l-sam[fa].len )*siz[p];
		}	
	}
	cout << ans;
}
上一篇:轻重链剖分 学习笔记


下一篇:树上DP5