题目描述
给一个字符集大小为26的Trie树,节点数目为n,有Q次询问,每次问一个节点对应的字符串在另一个节点对应的字符串中作为子串出现了多少次
输入格式:
输入的第一行包含一个字符串,按阿狸的输入顺序给出所有阿狸输入的字符。
第二行包含一个整数m,表示询问个数。
接下来m行描述所有由小键盘输入的询问。其中第i行包含两个整数x, y,表示第i个询问为(x, y)。
输出格式:
输出m行,其中第i行包含一个整数,表示第i个询问的答案。
数据范围
对于100%的数据,n<=100000,m<=100000,第一行总长度<=100000。
题解
子串是前缀的后缀,所以如果y在x中出现,那么从x的前缀跳fail可以跳到y,记录有多少个点可以到y即为答案
优化:能跳到y,那么在fail树上就是y的子孙,对fail树求dfs序。
在x存下询问,别的字符串在x中出现几次。对Trie树dfs,进入一个节点就用树状数组在对应的dfs序+1,退出-1。遇到询问,就查询y的子树。
这样就可以查出这条路径上有多少点可以到y,而且可以遍历出所有路径。
在建Trie树时,不能遇到P才从头插入,要记录当前节点,有B就跳fa,P就记录字符串。
对于fail树的建立不能打乱原Trie树,所以要建两颗
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<vector> using namespace std; const int maxn=400005; int n=-1,m,num,cnt; int mp[maxn],go[maxn][26],fa[maxn]; int gogo[maxn][26],fail[maxn];//tire图 int head[maxn],ans[maxn]; int in[maxn],size[maxn]; int tr[maxn<<1]; char t[maxn]; vector<pair<int,int> >que[maxn]; struct edge{ int y,next; }e[maxn]; void addedge(int x,int y){ e[++cnt].y=y; e[cnt].next=head[x]; head[x]=cnt; } int q[maxn]; void get_fail(){ cnt=0; int h=0,tail=0; for(int i=0;i<26;i++) if(gogo[0][i]){q[tail++]=gogo[0][i];addedge(0,gogo[0][i]);} while(h<tail){ int x=q[h++]; for(int i=0;i<26;i++){ if(!gogo[x][i]) gogo[x][i]=gogo[fail[x]][i]; else{ fail[gogo[x][i]]=gogo[fail[x]][i]; addedge(gogo[fail[x]][i],gogo[x][i]); q[tail++]=gogo[x][i]; } } } } void dfs(int u){ in[u]=++cnt;size[u]=1; for(int i=head[u];i;i=e[i].next){ dfs(e[i].y); size[u]+=size[e[i].y]; } } void add(int x,int val){for(;x<=cnt;x+=x&-x) tr[x]+=val;} int sum(int x){ int ret=0; for(;x;x-=x&-x) ret+=tr[x]; return ret; } void DFS(int now){ if(que[now].size()){ for(unsigned int i=0;i<que[now].size();i++){ int x=que[now][i].first,y=que[now][i].second; ans[y]=sum(in[x]+size[x]-1)-sum(in[x]-1); } } for(int i=0;i<26;i++) if(go[now][i]){ add(in[go[now][i]],1); DFS(go[now][i]); add(in[go[now][i]],-1); } } int main(){ scanf("%s%d",t,&m); int len=strlen(t); int now=0; for(int i=0;i<len;i++){ if(t[i]=='B') now=fa[now]; else if(t[i]=='P') mp[++cnt]=num; else { int c=t[i]-'a'; if(!go[now][c]) gogo[now][c]=go[now][c]=++num,fa[num]=now; now=go[now][c]; } } for(int i=1;i<=m;i++){ int x,y; scanf("%d%d",&x,&y); que[mp[y]].push_back(make_pair(mp[x],i)); } get_fail(); cnt=0; dfs(0); DFS(0); for(int i=1;i<=m;i++) printf("%d\n",ans[i]); }