// Code:
// Weighted substring-chain DP over an Aho-Corasick automaton.
// NOTE(review): this looks like HDU 4117 "GRE Words" -- confirm the problem.
//
// Input: T test cases; each is n followed by n lines "word weight".
// Processing words in input order, for every word i with w[i] > 0:
//   dp[i] = w[i] + max(0, max dp[j] over earlier j whose word is a substring of word i)
// and the program prints max(0, max_i dp[i]) per test case.
//
// Substring test: word j is a substring of word i iff word j's trie node is a
// fail-tree ancestor of some prefix node of word i. Once dp[j] is known it is
// chmax-ed onto the fail-subtree of word j's node (a contiguous DFS interval),
// so a point query at each prefix node of word i sees every earlier word that
// is a suffix of that prefix.
#include <cstdio>
#include <queue>
#include <map>
#include <algorithm>
#include <cstring>

#define setIO(s) freopen(s".in", "r", stdin)

using namespace std;

const int N = 300002;      // capacity for trie nodes (total word length + root)
const int SIGMA = 27;      // child slots per trie node (indexed by ch - 'a')

int T, cas;

// Segment tree over DFS positions [1, tim].
// update(): chmax every position in [L, R] by storing a tag on the covering
// nodes (tags are never pushed down).
// query(): the value at position p is the max tag on its root-to-leaf path.
struct Seg {
    int tag[N << 2];

    void update(int l, int r, int now, int L, int R, int v) {
        if (L <= l && r <= R) {
            tag[now] = max(tag[now], v);
            return;
        }
        int mid = (l + r) >> 1;
        if (L <= mid) update(l, mid, now << 1, L, R, v);
        if (R > mid) update(mid + 1, r, now << 1 | 1, L, R, v);
    }

    int query(int l, int r, int now, int p, int pre) {
        pre = max(pre, tag[now]);
        if (l == r) return pre;
        int mid = (l + r) >> 1;
        if (p <= mid) return query(l, mid, now << 1, p, pre);
        return query(mid + 1, r, now << 1 | 1, p, pre);
    }
} seg;

// Trie / automaton node. After build(), ch[] holds goto transitions.
struct Node {
    int f;            // fail link
    int par;          // trie parent (kept intact by build(); used to walk prefixes)
    int ch[SIGMA];
} t[N];

queue<int> q;
char str[N];
int n, tot, tim, edges;
int w[N], endpos[N];              // weight and end node of each input word
int hd[N], nex[N], to[N];         // fail-tree adjacency lists
int dfn[N], sz[N];                // DFS entry index and subtree size in the fail tree
int stk[N], cur[N];               // explicit stack + edge cursor for the iterative DFS

void addedge(int u, int v) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

// Iterative pre-order DFS of the fail tree from `root`, filling dfn[] and sz[].
// Iterative because the fail tree can be a chain ~N deep (e.g. "aaaa...").
// Children are visited in the same order a recursive DFS would use.
void dfs(int root) {
    int top = 0;
    dfn[root] = ++tim;
    sz[root] = 1;
    cur[root] = hd[root];
    stk[top++] = root;
    while (top) {
        int u = stk[top - 1];
        if (cur[u]) {
            int v = to[cur[u]];
            cur[u] = nex[cur[u]];
            dfn[v] = ++tim;
            sz[v] = 1;
            cur[v] = hd[v];
            stk[top++] = v;
        } else {
            --top;
            if (top) sz[stk[top - 1]] += sz[u];
        }
    }
}

// Inserts str[1..] into the trie and returns the node where it ends.
int insert() {
    int len = strlen(str + 1), rt = 0;
    for (int i = 1; i <= len; ++i) {
        int c = str[i] - 'a';
        if (!t[rt].ch[c]) {
            t[rt].ch[c] = ++tot;
            t[tot].par = rt;
        }
        rt = t[rt].ch[c];
    }
    return rt;
}

// BFS that computes fail links and fills missing ch[] slots with goto transitions.
void build() {
    for (int i = 0; i < SIGMA; ++i)
        if (t[0].ch[i]) q.push(t[0].ch[i]);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int i = 0; i < SIGMA; ++i) {
            int p = t[u].ch[i];
            if (!p) {
                t[u].ch[i] = t[t[u].f].ch[i];
                continue;
            }
            t[p].f = t[t[u].f].ch[i];
            q.push(p);
        }
    }
}

// Reads one test case, prints its answer, and resets all state that was touched.
void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%s%d", str + 1, &w[i]);
        // Non-positive words are never inserted: "is a substring of" is
        // transitive, so dropping them from a chain keeps it valid and never
        // lowers its sum.
        if (w[i] > 0) endpos[i] = insert();
    }
    build();
    for (int i = 1; i <= tot; ++i) addedge(t[i].f, i);
    dfs(0);

    int answer = 0;
    for (int i = 1; i <= n; ++i) {
        if (w[i] <= 0) continue;
        int best = 0;
        // Visit every prefix node of word i (end node up to, excluding, the root):
        // each earlier word that is a substring of word i is a suffix of one of them.
        for (int p = endpos[i]; p; p = t[p].par)
            best = max(best, seg.query(1, tim, 1, dfn[p], 0));
        best += w[i];
        answer = max(answer, best);
        int e = endpos[i];
        seg.update(1, tim, 1, dfn[e], dfn[e] + sz[e] - 1, best);
    }
    // NOTE(review): `cas` is tracked but not printed; if the judge expects a
    // "Case #x: " prefix it would go here -- confirm against the problem statement.
    printf("%d\n", answer);

    // Reset only what this test case touched (the full trie is ~34 MB).
    memset(hd, 0, sizeof(int) * (tot + 1));
    memset(endpos, 0, sizeof(int) * (n + 1));
    memset(t, 0, sizeof(Node) * (tot + 1));
    memset(seg.tag, 0, sizeof(int) * min(4 * (tim + 1), N << 2));
    tot = tim = edges = 0;
}

int main() {
    // setIO("input");
    scanf("%d", &T);
    for (cas = 1; cas <= T; ++cas) solve();
    return 0;
}