vjudge传送门
题面:\(Q\)次询问,每次求子串\([S_l, S_r]\)第\(k\)次出现的位置。
首先,对于原串中的每一个子串,都能在SAM的后缀链接树上找到对应的节点。
那么如果我们知道这个节点的endpos集合,就能找到第\(k\)次出现的位置了。
所以接下来就要解决这两个问题:
1.快速确定子串在后缀链接树上的位置。
2.求出某个节点的endpos结合。
1.我们能轻松求出的是,位置\(S_r\)所在的节点\(u\)。那么\([S_l,S_r]\)所代表的子串只可能是\(u\)或\(u\)的祖先,准确来说,是深度最小的节点\(v\),满足\(len[v] \geqslant S_r - S_l + 1\)。这个用倍增就能求出来。
2.怎么求endpos个数?因为一个节点的endpos集合由他的所有子节点合并而来,这启发我们可以自底向上的合并求出所有节点的endpos。但如果暴力合并或者启发式合并,时间空间都不理想,所以用线段树合并就行了。代码中写的线段树合并没有垃圾回收,不过也过了。
求出endpos集合后,在线段树上二分就能找到第\(k\)次出现的位置了。
时间复杂度\(O(nlogn)\)(线段树合并复杂度为\(O(nlogn)\),倍增复杂度为\(O(nlog2n)\)).
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int maxn = 1e5 + 5;
const int maxs = 27;
const int maxt = 4e7 + 5;
const int N = 17;
In ll read()
{
ll ans = 0;
char ch = getchar(), las = ' ';
while(!isdigit(ch)) las = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(las == '-') ans = -ans;
return ans;
}
In void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, Q, ans[maxn];
char s[maxn];
struct Node{int K, id, len;};
vector<Node> v[maxn << 1];
struct Tree
{
int ls, rs, sum;
}t[maxt];
int root[maxn << 1], cnt = 0;
In int New()
{
++cnt;
t[cnt].ls = t[cnt].rs = t[cnt].sum = 0;
return cnt;
}
In void update(int& now, int l, int r, int d)
{
if(!now) now = New();
t[now].sum++;
if(l == r) return;
int mid = (l + r) >> 1;
if(d <= mid) update(t[now].ls, l, mid, d);
else update(t[now].rs, mid + 1, r, d);
}
In int merge(int x, int y, int l, int r)
{
if(!x || !y) return x | y;
if(l == r) {t[x].sum += t[y].sum; return x;}
int mid = (l + r) >> 1, z = ++cnt;
t[z].ls = merge(t[x].ls, t[y].ls, l, mid);
t[z].rs = merge(t[x].rs, t[y].rs, mid + 1, r);
t[z].sum = t[t[z].ls].sum + t[t[z].rs].sum;
return z;
}
In int query(int now, int l, int r, int d)
{
if(t[now].sum < d) return -1;
if(l == r) return l;
int mid = (l + r) >> 1, Sum = t[t[now].ls].sum;
if(d <= Sum) return query(t[now].ls, l, mid, d);
else return query(t[now].rs, mid + 1, r, d - Sum);
}
In void _Print(int now, int l, int r) //调试用
{
// printf("_______%d %d %d\n", now, l, r);
if(!t[now].sum) return;
if(l == r) {write(l), space; return;}
int mid = (l + r) >> 1;
_Print(t[now].ls, l, mid), _Print(t[now].rs, mid + 1, r);
}
int cur[maxn];
struct Sam
{
int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1], cnt, las;
In void init()
{
link[cnt = las = 0] = -1;
Mem(tra[0], 0);
Mem(buc, 0), Mem(pos, 0);
}
In void insert(int c)
{
int now = ++cnt, p = las; Mem(tra[now], 0);
len[now] = len[las] + 1;
while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
if(p == -1) link[now] = 0;
else
{
int q = tra[p][c];
if(len[q] == len[p] + 1) link[now] = q;
else
{
int clo = ++cnt; memcpy(tra[clo], tra[q], sizeof(tra[q]));
len[clo] = len[p] + 1;
link[clo] = link[q]; link[q] = link[now] = clo;
while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
}
}
las = now;
}
int buc[maxn << 1], pos[maxn << 1];
In void solve()
{
for(int i = 1; i <= cnt; ++i) buc[len[i]]++;
for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
for(int i = cnt; i >= 0; --i)
{
int now = pos[i], fa = link[pos[i]];
for(auto it : v[now]) ans[it.id] = query(root[now], 0, n - 1, it.K) + 1 - it.len + 1;
root[fa] = merge(root[fa], root[now], 0, n - 1);
}
}
}S;
int fa[N + 2][maxn << 1];
In int solve(int x, int len)
{
for(int i = N; i >= 0; --i)
if(fa[i][x] && S.len[fa[i][x]] >= len) x = fa[i][x];
return x;
}
In void init()
{
cnt = 0, Mem(root, 0); Mem(fa, 0);
for(int i = 0; i <= (n << 1); ++i) v[i].clear();
S.init();
}
int main()
{
int T = read();
while(T--)
{
n = read(), Q = read();
init();
scanf("%s", s);
for(int i = 0; i < n; ++i)
{
S.insert(s[i] - 'a'); cur[i] = S.las;
update(root[cur[i]], 0, n - 1, i);
}
for(int i = 1; i <= cnt; ++i) fa[0][i] = S.link[i];
for(int j = 1; j <= N; ++j)
for(int i = 1; i <= cnt; ++i) fa[j][i] = fa[j - 1][fa[j - 1][i]];
for(int i = 1; i <= Q; ++i)
{
int L = read() - 1, R = read() - 1, K = read();
int p = solve(cur[R], R - L + 1);
v[p].push_back((Node){K, i, R - L + 1});
}
S.solve();
for(int i = 1; i <= Q; ++i) write(ans[i] <= 0 ? -1 : ans[i]), enter;
}
return 0;
}