题目描述
(我并不想告诉你题目名字是什么鬼)
有一个长度为n的仅包含小写字母的字符串S,下标范围为[1,n].
现在有若干组询问,对于每一个询问,我们给出若干个后缀(以其在S中出现的起始位置来表示),求这些后缀两两之间的LCP(LongestCommonPrefix)的长度之和.一对后缀之间的LCP长度仅统计一遍.
输入
第一行两个正整数n,m,分别表示S的长度以及询问的次数.
接下来一行有一个字符串S.
接下来有m组询问,对于每一组询问,均按照以下格式在一行内给出:
首先是一个整数t,表示共有多少个后缀.接下来t个整数分别表示t个后缀在字符串S中的出现位置.
输出
样例输入
7 3
popoqqq
1 4
2 3 5
4 1 2 5 6
样例输出
0
0
2
题解
后缀数组+倍增RMQ+单调栈
首先预处理出sa和height数组。
然后对于每组询问,将要求的后缀去重后按照rank从小到大排序。
由于我们有:LCP(a,c)=min(LCP(a,b),LCP(b,c)),其中rank[a]<rank[b]<rank[c]
所以我们只需要知道相邻两个要求的后缀之间的LCP,即可推出任意两个后缀的LCP。
这里求LCP的方式是倍增RMQ,所以我偷改了height的定义:height[i][j]表示排名为i-2^j的后缀与排名为i的后缀的LCP。
这样转化成了一个新的问题:给你n个数,求其每个子区间中最小值的和。
考虑对答案的贡献:ai对答案的贡献是满足l∈[lpos,i],r∈[i,rpos]的所有区间[l,r],也即ai*(i-lpos+1)*(rpos-i+1),其中lpos是i左侧最后一个大于i的,rpos是i右侧最后一个大于等于i的。
(左右包含等号的情况不同是为了处理相同的数,防止重复或漏算)
可以用一个单调栈来在线性时间内求出i-lpos+1和rpos-i+1,具体方法见代码。
最后的最后,需要把用于去重的数组vis清零,注意不能用memset。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 500010
#define mod 23333333333333333ll
using namespace std;
int n , m , sa[N] , r[N] , ws[N] , wa[N] , wb[N] , wv[N] , rank[N] , height[N][21] , log[N] , num[N * 6] , vis[N] , pos[N] , val[N] , sta[N] , top , lp[N] , rp[N];
char str[N];
void da()
{
int i , j , p , *x = wa , *y = wb;
for(i = 0 ; i < m ; i ++ ) ws[i] = 0;
for(i = 0 ; i < n ; i ++ ) ws[x[i] = r[i]] ++ ;
for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1];
for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[x[i]]] = i;
for(p = j = 1 ; p < n ; j <<= 1 , m = p)
{
for(p = 0 , i = n - j ; i < n ; i ++ ) y[p ++ ] = i;
for(i = 0 ; i < n ; i ++ ) if(sa[i] - j >= 0) y[p ++ ] = sa[i] - j;
for(i = 0 ; i < n ; i ++ ) wv[i] = x[y[i]];
for(i = 0 ; i < m ; i ++ ) ws[i] = 0;
for(i = 0 ; i < n ; i ++ ) ws[wv[i]] ++ ;
for(i = 1 ; i < m ; i ++ ) ws[i] += ws[i - 1];
for(i = n - 1 ; i >= 0 ; i -- ) sa[--ws[wv[i]]] = y[i];
for(swap(x , y) , x[sa[0]] = 0 , p = i = 1 ; i < n ; i ++ )
{
if(y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) x[sa[i]] = p - 1;
else x[sa[i]] = p ++ ;
}
}
for(i = 1 ; i < n ; i ++ ) rank[sa[i]] = i;
for(p = i = 0 ; i < n - 1 ; height[rank[i ++ ]][0] = p)
for(p ? p -- : 0 , j = sa[rank[i] - 1] ; r[i + p] == r[j + p] ; p ++ );
}
int query(int x , int y)
{
x ++ ;
int k = log[y - x + 1];
return min(height[x + (1 << k) - 1][k] , height[y][k]);
}
bool cmp(int a , int b)
{
return rank[a] < rank[b];
}
int main()
{
int i , j , k , cnt , tot;
long long ans;
scanf("%d%d%s" , &n , &k , str);
for(i = 0 ; i < n ; i ++ ) r[i] = str[i] - 'a' + 1;
n ++ , m = 28 , da() , n -- ;
for(i = 2 ; i <= n ; i ++ ) log[i] = log[i >> 1] + 1;
for(i = 1 ; i <= log[n] ; i ++ )
for(j = (1 << i) ; j <= n ; j ++ )
height[j][i] = min(height[j][i - 1] , height[j - (1 << (i - 1))][i - 1]);
while(k -- )
{
scanf("%d" , &cnt);
tot = 0 , ans = 0;
for(i = 1 ; i <= cnt ; i ++ )
{
scanf("%d" , &num[i]) , num[i] -- ;
if(!vis[num[i]]) vis[num[i]] = 1 , pos[++tot] = num[i];
}
sort(pos + 1 , pos + tot + 1 , cmp);
for(i = 1 ; i < tot ; i ++ )
val[i] = query(rank[pos[i]] , rank[pos[i + 1]]);
sta[0] = top = 0;
for(i = 1 ; i < tot ; i ++ )
{
while(top && val[sta[top]] > val[i]) top -- ;
lp[i] = i - sta[top] , sta[++top] = i;
}
sta[0] = tot , top = 0;
for(i = tot - 1 ; i ; i -- )
{
while(top && val[sta[top]] >= val[i]) top -- ;
rp[i] = sta[top] - i , sta[++top] = i;
}
for(i = 1 ; i < tot ; i ++ ) ans = (ans + (long long)lp[i] * rp[i] * val[i]) % mod;
printf("%lld\n" , ans);
for(i = 1 ; i <= cnt ; i ++ ) vis[num[i]] = 0;
}
return 0;
}