luogu P5287 [HNOI2019]JOJO

传送门

神™这题暴力能A,这出题人都没造那种我考场就想到的数据,难怪我的垃圾做法有分

先考虑没有撤销操作怎么做,因为每次插入一段一样的字符,所以我们可以把\(x\)个字符\(c\)定义为\(cx\),然后用这种新字符做\(\mathrm{kmp}\).但是直接把一般的\(\mathrm{kmp}\)搬过来做是错的.例如\(yybbbyybb\),最后一个b的\(next\)是第二个b,但是题目有个限制,每次往后加的字符不会和上一个字符相同,那么现在往后加任何字符,因为都不等于b,所以就一定无法继续匹配.所以,新的\(\mathrm{kmp}\)的\(next\)指向的字符必须和当前字符的长度和字符类型要一致,这样才能继续往后接东西.还有一种特殊情况,如果当前字符的类型和第一个相同,但是长度比第一个长,那么\(next\)应该指向1,因为这样也是可以往后接东西的

然后考虑统计答案,因为这个\(\mathrm{kmp}\)可能会跳过一些匹配,所以我们在暴跳\(next\)的过程中顺便统计答案,就是如果\(next\)后面的字符类型和当前相同,答案要加上 当前字符没算过答案的部分 到那两个字符长度最小值的这个区间 在\(next\)后面的字符中的到开始位置的长度之和(请感性理解)

然后有撤销操作,直接上个可持久化我们发现可以把所有操作的串建一个trie树,然后在上面dfs做,然后撤销,就可以不用可持久化了qwq

我不可持久化啦!JOJO!

这样就能获得100分的好成绩(误

然而\(\mathrm{kmp}\)是均摊\(O(n)\)的,出题人想卡你还是可以卡的.所以我们考虑一种叫\(\mathrm{kmp}\)自动机的东西,就是每个状态的后继状态表示这个状态后加一个字符,它的\(next\)会指向哪里,每次转移可以直接把\(next\)设成对应的后继,然后在把对应的后继状态指向当前位置.现在因为新的字符集比较大,所以可以使用可持久化线段树维护这个\(\mathrm{kmp}\)自动机.但是因为现在是一步跳到\(next\),所以考虑如何统计答案.我们可以手玩统计答案过程,假设串是\(...bbb...bb...\),然后我们会先算上前两个b在\(bb\)处的长度,然后加上第三个b在\(bbb\)处的长度,如果从前往后看,可以发现这是每次把一个前缀修改成一个元素值更大的等差数列(公差为1),然后答案只要前缀求和,所以这个东西用可持久化线段树,前缀赋值等差数列以及前缀求和来实现

#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<ctime>
#include<queue>
#include<map>
#include<set>
#define LL long long using namespace std;
const int N=1e5+10,M=2e4,mod=998244353;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
int q,n,to[N*100],ch[N*100][2],rt[N][26],t1;
void inst(int o1,int o2,int x,int y)
{
int l=1,r=n;
while(l<r)
{
int mid=(l+r)>>1;
if(x<=mid)
{
ch[o1][0]=++t1,ch[o1][1]=ch[o2][1];
o1=ch[o1][0],o2=ch[o2][0];
r=mid;
}
else
{
ch[o1][0]=ch[o2][0],ch[o1][1]=++t1;
o1=ch[o1][1],o2=ch[o2][1];
l=mid+1;
}
}
to[o1]=y;
}
int q1(int o,int x)
{
int l=1,r=n;
while(l<r)
{
int mid=(l+r)>>1;
if(x<=mid) o=ch[o][0],r=mid;
else o=ch[o][1],l=mid+1;
}
return to[o];
}
int s[N*100],tg[N*100],ch2[N*100][2],r2[N][26],t2;
int gsm(LL x,LL y){return (1ll*(x+x+y-1)*y/2)%mod;}
void psdn(int o,int l,int r)
{
if(l<r&&tg[o])
{
int mid=(l+r)>>1;
++t2,ch2[t2][0]=ch2[ch2[o][0]][0],ch2[t2][1]=ch2[ch2[o][0]][1],ch2[o][0]=t2,s[ch2[o][0]]=gsm(tg[ch2[o][0]]=tg[o],mid-l+1);
++t2,ch2[t2][0]=ch2[ch2[o][1]][0],ch2[t2][1]=ch2[ch2[o][1]][1],ch2[o][1]=t2,s[ch2[o][1]]=gsm(tg[ch2[o][1]]=tg[o]+mid-l+1,r-mid);
tg[o]=0;
}
}
void modif(int &o,int l,int r,int ll,int rr,LL x)
{
psdn(o,l,r);
int lc=ch2[o][0],rc=ch2[o][1];
++t2,s[t2]=gsm((x+max(l-ll,0))%mod,min(r,rr)-max(l,ll)+1),o=t2;
if(ll<=l&&r<=rr){tg[o]=(x+max(l-ll,0))%mod;return;}
int mid=(l+r)>>1;
ch2[o][0]=lc;
if(ll<=mid) modif(ch2[o][0],l,mid,ll,rr,x);
ch2[o][1]=rc;
if(rr>mid) modif(ch2[o][1],mid+1,r,ll,rr,x);
s[o]=(s[ch2[o][0]]+s[ch2[o][1]])%mod;
}
int quer(int o,int l,int r,int ll,int rr)
{
if(!o) return 0;
psdn(o,l,r);
if(ll<=l&&r<=rr) return s[o];
int mid=(l+r)>>1,an=0;
if(ll<=mid) an+=quer(ch2[o][0],l,mid,ll,rr);
if(rr>mid) an+=quer(ch2[o][1],mid+1,r,ll,rr);
return an%mod;
}
int an[N],sg[26],p[N],nw,aa[N],t0;
vector<int> e[N];
int sta[N],mx[N][26],tp;
LL len[N];
void wk(int x)
{
sta[++tp]=aa[x];
int xx=aa[x]/M,ln=aa[x]%M;
len[tp]=(len[tp-1]+ln)%mod;
int nxt=0;
if(tp==1) an[x]=gsm(1,ln-1);
else
{
nxt=q1(rt[tp][xx],ln);
an[x]=(an[x]+quer(r2[tp][xx],1,n,1,ln))%mod;
if(!nxt&&sta[1]/M==xx&&sta[1]%M<ln) nxt=1,an[x]=(an[x]+1ll*len[1]*max(ln-mx[tp][xx],0)%mod)%mod;
}
int las=rt[tp][xx];
inst(rt[tp][xx]=++t1,las,ln,tp);
modif(r2[tp][xx],1,n,1,ln,len[tp-1]+1);
mx[tp][xx]=max(mx[tp][xx],ln);
int nn=e[x].size();
for(int i=0;i<nn;++i)
{
an[e[x][i]]=an[x];
memcpy(rt[tp+1],rt[nxt+1],sizeof(int)*26),memcpy(r2[tp+1],r2[nxt+1],sizeof(int)*26),memcpy(mx[tp+1],mx[nxt+1],sizeof(int)*26);
wk(e[x][i]);
}
--tp;
} int main()
{
//awsl
q=rd();
char cc[4];
for(int i=1;i<=q;++i)
{
int op=rd(),x=rd();
if(op==1)
{
n=max(n,x);
scanf("%s",cc);
p[i]=++t0;
e[nw].push_back(t0),aa[t0]=(cc[0]-'a')*M+x,nw=t0;
}
else nw=p[i]=p[x];
}
int nn=e[0].size();
for(int i=0;i<nn;++i)
memset(rt[1],0,sizeof(int)*26),memset(r2[1],0,sizeof(int)*26),memset(mx[1],0,sizeof(int)*26),wk(e[0][i]);
for(int i=1;i<=q;++i) printf("%d\n",an[p[i]]);
return 0;
}
上一篇:python笔记之字符串


下一篇:【转载】TCP/IP协议栈