题面
思路
-
把n个字符串用string存,按字典序排序,建一棵trie,再倒序,排序,再建一棵trie
-
我们发现,对于trie树上的每一个节点,包含它的字符串的序号是连续的,区间左端点是包括这个节点的字典序最小的字符串的序号,区间右端点是包括这个节点的字典序最大的字符串的序号
-
设正序排的区间[l1,r1],倒序排的区间[l2,r2],每个字符串的正序序号i1,倒序序号i2
-
对于每个询问,满足条件的字符串l1<=i1<=r1且l2<=i2<=r2,但i1和i2是一一对应的,所以设\(id_i\)是每个倒序序号是i的字符串的正序序号,所以有\(l_1\)<=\(id_i\)<=\(r_1\),\(l_2\)<=i<=\(r_2\)
5.于是这就转化成了 二维数点 问题,以i为x轴,\(id_i\)为y轴,每个字符串是一个要数的点,每个询问的矩形以(l2,l1),(l2,r1),(r2,l1),(r2,r1)为四个顶点(PS:二维数点详见(https://www.cnblogs.com/Ritalc/p/14743019.html))
code
#include<bits/stdc++.h>
#define int long long
#define N 100010
#define M 2000010
#define re register
using namespace std;
int n,m,ans[N<<2],tot[5],t[2][M][4],l[2][M],r[2][M],cnt,book[N],tree[9*N];
char aa[M];
string ss;
template <class T> inline void read(T &x)
{
x=0;int g=1;char s=getchar();
for (;s<'0'||s>'9';s=getchar()) if (s=='-') g=-1;
for (;s>='0'&&s<='9';s=getchar()) x=(x<<1)+(x<<3)+(s^48);
x*=g;
}
struct str
{
int id;string s;
}s1[N];
bool cmp1(str x,str y)
{
return x.s<y.s;
}
void insert(string s,int id,int q)
{
int p=1,len=s.size();
for (int i=0;i<len;i++)
{
int ch;
if (s[i]=='A') ch=0;
else if (s[i]=='C') ch=1;
else if (s[i]=='G') ch=2;
else ch=3;
if (!t[q][p][ch]) t[q][p][ch]=++tot[q],l[q][tot[q]]=id;
p=t[q][p][ch];
r[q][p]=id;
}
}
int query(string s,int q)
{
int len=s.size(),p=1;
for (int i=0;i<len;i++)
{
int ch;
if (s[i]=='A') ch=0;
else if (s[i]=='C') ch=1;
else if (s[i]=='G') ch=2;
else ch=3;
p=t[q][p][ch];
if (p==0) return 0;
}
return p;
}
struct node
{
int x,y,t,id;
}e[9*N];
bool cmp2(node x,node y)
{
if (x.x==y.x)
{
if (x.y==y.y) return x.t<y.t;
return x.y<y.y;
}
return x.x<y.x;
}
void add(int x)
{
for (;x<=cnt;x+=(x&(-x))) tree[x]++;
}
int ask(int x)
{
int tmp=0;
for (;x;x-=(x&(-x))) tmp+=tree[x];return tmp;
}
signed main()
{
re int i,j,x,y,z,op;
read(n);read(m);tot[0]=tot[1]=1;
for (i=1;i<=n;i++)
{
scanf("%s",aa);
s1[i].s=aa;
}
sort(s1+1,s1+n+1,cmp1);
for (i=1;i<=n;i++) s1[i].id=i;
for (i=1;i<=n;i++) insert(s1[i].s,s1[i].id,0);
for (i=1;i<=n;i++) reverse(s1[i].s.begin(),s1[i].s.end());
sort(s1+1,s1+n+1,cmp1);
for (i=1;i<=n;i++) insert(s1[i].s,i,1);
for (i=1;i<=n;i++)
{
++cnt;e[cnt].x=i;e[cnt].y=s1[i].id;e[cnt].t=0;e[cnt].id=i;
}
for (i=1;i<=m;i++)
{
scanf("%s",aa);ss=aa;
int tmp1=query(ss,0);
scanf("%s",aa);ss=aa;reverse(ss.begin(),ss.end());
int tmp2=query(ss,1);
if (tmp1==0||tmp2==0) {book[i]=1;continue;}//一个容易挂掉的小细节:不continue会在add中产生-1的数组下标
int l1=l[0][tmp1],r1=r[0][tmp1];
int l2=l[1][tmp2],r2=r[1][tmp2];
++cnt;e[cnt].x=l2-1;e[cnt].y=l1-1;e[cnt].t=1;e[cnt].id=i;
++cnt;e[cnt].x=r2;e[cnt].y=r1;e[cnt].t=1;e[cnt].id=i+m;
++cnt;e[cnt].x=l2-1;e[cnt].y=r1;e[cnt].t=1;e[cnt].id=i+2*m;
++cnt;e[cnt].x=r2;e[cnt].y=l1-1;e[cnt].t=1;e[cnt].id=i+3*m;
}
sort(e+1,e+cnt+1,cmp2);
for (i=1;i<=cnt;i++)
{
if (e[i].t==0) add(e[i].y);
else ans[e[i].id]=ask(e[i].y);
}
for (i=1;i<=m;i++)
{
if (book[i]) printf("0\n");
else
{
int tmp=ans[i]+ans[i+m]-ans[i+2*m]-ans[i+3*m];
printf("%lld\n",tmp);
}
}
return 0;
}