[「NOI2018」你的名字](https://loj.ac/problem/2720)
题目描述
小A 被选为了\(ION2018\) 的出题人,他精心准备了一道质量十分高的题目,且已经 把除了题目命名以外的工作都做好了。
由于\(ION\) 已经举办了很多届,所以在题目命名上也是有规定的,\(ION\) 命题手册规 定:每年由命题委员会规定一个小写字母字符串,我们称之为那一年的命名串,要求每道题的名字必须是那一年的命名串的一个非空连续子串,且不能和前一年的任何一道题目的名字相同。
由于一些特殊的原因,小A 不知道\(ION2017\) 每道题的名字,但是他通过一些特殊 手段得到了\(ION2017\) 的命名串,现在小A 有\(Q\) 次询问:每次给定\(ION2017\) 的命名串和\(ION2018\) 的命名串,求有几种题目的命名,使得这个名字一定满足命题委员会的规定,即是\(ION2018\) 的命名串的一个非空连续子串且一定不会和\(ION2017\) 的任何一道题目的名字相同。
由于一些特殊原因,所有询问给出的\(ION2017\) 的命名串都是某个串的连续子串, 详细可见输入格式。
输入格式:
第一行一个字符串\(S\) ,之后询问给出的\(ION2017\) 的命名串都是\(S\) 的连续子串。 第二行一个正整数\(Q\),表示询问次数。 接下来\(Q\) 行,每行有一个字符串\(T\) 和两个正整数\(l,r\),表示询问如果\(ION2017\) 的 命名串是\(S[l..r]\),\(ION2018\) 的命名串是\(T\) 的话,有几种命名方式一定满足规定。
输出格式:
输出\(Q\)行,第\(i\) 行一个非负整数表示第\(i\) 个询问的答案。
先放一个乱搞一个做法(后面补了正解) :
首先考虑 \(l = 1, r = |S|\) 的部分分,我在同步赛上的乱搞做法是对 \(S\) 建 \(Sam\) ,每次询问往 \(Sam\) 里面插入询问串
取出新增的后缀节点,暴力在 \(parent\) 树上跳父亲,计算在 \(S\) 中出现的不同子串个数,以及总共的不同子串个数,相减就是答案
因为相同的子串只会被算一次,所以每一个节点的贡献只会被算一次,单次复杂度是向上跳的期望节点数,根据 \(Sam\) 的一些奇奇怪怪的性质,复杂度上限是\(O(n\sqrt{n})\) ,但是在实际情况下根本卡不满,这 \(68pt\) 中最慢的点是 \(0.8s\)
讨论有了 \(l, r\) 的限制的情况,在原先的算法基础上还需要求出每个节点在限制下能表示的最长的公共子串长度为 \(maxlen\)
设 \(r\) 在该节点 \(right\) 集合中的前驱是 \(r'\) ,那么 \(mxlen = r' - l + 1\) ,通过这个重新计算公共子串个数即可
考虑求出这个东西只需要在 \(Sam\) 上大力线段树合并即可,但是无论多么不满乘上 \(log\) 都会 \(Tle\)
考虑进行剪枝,对每一个节点维护 \(mx_u\) 和 \(mn_u\) 表示其 \(right\) 集合中在 \(S\) 串中出现的最靠前和最靠后的位置
如果有 \(r < mn_u\) 或者 \(l > mx_u\) 这个节点就不会有贡献,可以剪掉
观察发现,大部分情况下 \(mxlen > dep_u\) ,此时求前驱没有任何用处,本质上是因为 \(mx_u > r\) 的缘故
所以在此可以再加上一个剪枝,这样线段树的查询只会在很深的几个节点被调用了.
测一下最大的数据惊奇的发现只需要\(2.5s\) ,交上去发现用了一个 \(O(n\sqrt{n}logn)\)的乱搞水过了此题,(震惊!)
/*pragram by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 3500005, MAXN = 1500005;
char s[N]; int rt[N], buf[N], n;
struct SegmentTree{
int lc[MAXN*25], rc[MAXN*25], sz[MAXN*25], size;
inline SegmentTree(){ size = 1; }
inline void ins(int &u, int l, int r, int pos){
if(!u) u = ++size;
if(l == r) return (void) (sz[u]++);
int mid = l + r >> 1;
if(pos <= mid) ins(lc[u], l, mid, pos);
else ins(rc[u], mid + 1, r, pos);
sz[u] = sz[lc[u]] + sz[rc[u]];
}
inline int merge(int x, int y, int l, int r){
if(!x || !y) return x + y;
int mid = l + r >> 1, o = ++size;
if(l == r) sz[o] = sz[x] + sz[y];
else{
lc[o] = merge(lc[x], lc[y], l, mid);
rc[o] = merge(rc[x], rc[y], mid + 1, r);
sz[o] = sz[lc[o]] + sz[rc[o]];
}
return o;
}
inline int query(int u, int l, int r, int pos){
if(!sz[u]) return 0;
if(l == r) return l;
int mid = l + r >> 1;
if(pos <= mid) return query(lc[u], l, mid, pos);
int rans = query(rc[u], mid + 1, r, pos);
return rans ? rans : query(lc[u], l, mid, pos);
}
}Seg;
struct SuffixAutomaton{
vector<int> g[N], v; ll dep[N];
int ch[N][26], fa[N], vis[N], mx[N], mn[N], tail, size;
inline SuffixAutomaton(){ size = tail = 1, rt[1] = 1; }
inline int newnode(int x){ return dep[++size] = x, size; }
inline void ins(int c, int ff, int pos){
int p = tail, np = newnode(dep[p] + 1);
if(ff) v.push_back(np); else{
Seg.ins(rt[np], 1, n, pos);
mx[np] = mn[np] = pos;
}
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(dep[q] == dep[p] + 1) fa[np] = q;
else{
int nq = newnode(dep[p] + 1);
fa[nq] = fa[q], fa[np] = fa[q] = nq;
if(ff) rt[nq] = rt[q], mx[nq] = mx[q], mn[nq] = mn[q];
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline void addedge(){
for(int i = 2; i <= size; i++) g[fa[i]].push_back(i);
}
inline void dfs(int u){
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
dfs(v), rt[u] = Seg.merge(rt[u], rt[v], 1, n);
mx[u] = Max(mx[u], mx[v]), mn[u] = Min(mn[u], mn[v]);
}
}
inline void prepare(char *s){
for(int i = 0; i < n; i++) ins(s[i] - 'a', 0, i + 1);
addedge(), dfs(1);
}
inline ll calc(char *s, int l, int r){
tail = 1; v.clear();
ll len = strlen(s), ans = 0, all = 0;
for(int i = 0; i < len; i++) ins(s[i] - 'a', 1, 0);
for(int i = 0; i < v.size(); i++){
int u = v[i];
for(int p = u; p > 1; p = fa[p]) {
if(vis[p]) break; int OK = 0;
all += dep[p] - dep[fa[p]], vis[p] = 1;
if(rt[p]){
if((l == 1 && r == n) || OK)
{ ans += dep[p] - dep[fa[p]]; continue; }
if(mx[p] < l || mn[p] > r) continue;
int mxlen = mx[p] <= r ? mx[p] - l + 1 : Seg.query(rt[p], 1, n, r) - l + 1;
if(mxlen > dep[fa[p]])
ans += Min(dep[p], mxlen) - dep[fa[p]];
if(mxlen >= dep[p]) OK = 1;
}
}
}
for(int i = 0; i < v.size(); i++){
int u = v[i];
for(int p = u; p > 1; p = fa[p]){
if(!vis[p]) break; vis[p] = 0;
}
}
return all - ans;
}
}van;
int main(){
scanf("%s", s); n = strlen(s);
int Q; read(Q), van.prepare(s);
while(Q--){
int l, r;
scanf("%s", s), read(l), read(r);
printf("%lld\n", van.calc(s, l, r));
}
}
正解
补集转换一步,问题变成求 \(T\) 与 \(S[l_i:r_i]\) 的本质不同的公共子串数,考虑让 \(T\) 在 \(S\) 的 \(sam\) 上匹配,双指针找出每一个前缀能在 \(S[l_i:r_i]\) 中能匹配上的后缀长度 \(len[i]\),然后在 \(T\) 的 \(sam\) 上统计答案,对于每一个节点随便找一个出现的前缀,拿这个前缀的 \([0,len[i]]\) 和其所能表示的字符串长度区间取交集即可。
求 \(len[i]\) 可以先找到第一个能接收当前字符 \(c\) 的节点,然后不断删去首字母,直到能在 \([l_i:r_i]\) 放下,也就是找到当前匹配节点的 \(right\) 集合,判断一段区间内是否有元素,这个用随便维护一下就好了,线段树合并蛮好写的。
/*program by mangoyang*/
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 2000005;
char s[N];
int res[N], n, q;
namespace Seg{
#define mid ((l + r) >> 1)
int lc[N*25], rc[N*25], sz[N*25], size;
inline void ins(int &u, int l, int r, int pos){
if(!u) u = ++size;
if(l == r) return (void) (sz[u]++);
if(pos <= mid) ins(lc[u], l, mid, pos);
else ins(rc[u], mid + 1, r, pos);
sz[u] = sz[lc[u]] + sz[rc[u]];
}
inline int merge(int x, int y, int l, int r){
if(!x || !y) return x + y;
int o = ++size;
if(l == r) sz[o] = sz[x] + sz[y];
else{
lc[o] = merge(lc[x], lc[y], l, mid);
rc[o] = merge(rc[x], rc[y], mid + 1, r);
sz[o] = sz[lc[o]] + sz[rc[o]];
}
return o;
}
inline int query(int u, int l, int r, int L, int R){
if(l >= L && r <= R) return sz[u];
int res = 0;
if(L <= mid) res += query(lc[u], l, mid, L, R);
if(mid < R) res += query(rc[u], mid + 1, r, L, R);
return res;
}
#undef mid
}
vector<int> vec[N];
namespace SAM1{
vector<int> g[N];
int ch[N][26], rt[N], fa[N], len[N], size = 1, tail = 1;
inline int newnode(int x){ return len[++size] = x, size; }
inline void ins(int c, int x){
int p = tail, np = newnode(len[p] + 1);
Seg::ins(rt[np], 1, n, x);
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(len[q] == len[p] + 1) fa[np] = q;
else{
int nq = newnode(len[p] + 1);
fa[nq] = fa[q], fa[q] = fa[np] = nq;
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline void addedge(){
for(int i = 2; i <= size; i++) g[fa[i]].push_back(i);
}
inline void dfs(int u){
for(int i = 0; i < (int) g[u].size(); i++)
dfs(g[u][i]), rt[u] = Seg::merge(rt[u], rt[g[u][i]], 1, n);
}
inline void solve(char *s, int L, int R){
int lenth = strlen(s + 1);
for(int i = 1, p = 1, now = 0; i <= lenth; i++){
int c = s[i] - 'a';
while(!ch[p][c] && p) p = fa[p], now = len[p];
if(!p){ p = 1, now = 0; continue; };
p = ch[p][c], now++;
while(p > 1){
if(Seg::query(rt[p], 1, n, L + now - 1, R)) break;
if(--now == len[fa[p]]) p = fa[p];
}
if(p == 1) continue;
for(int j = 0; j < (int) vec[i].size(); j++)
res[vec[i][j]] = max(res[vec[i][j]], now);
}
}
}
namespace SAM2{
int fa[N], len[N], ch[N][26], size, tail;
inline void Clear(){
for(int i = 1; i <= size; i++){
fa[i] = len[i] = res[i] = 0;
memset(ch[i], 0, sizeof(ch[i]));
}
size = tail = 1;
}
inline int newnode(int x){ return len[++size] = x, size; }
inline void ins(int c, int x){
int p = tail, np = newnode(len[p] + 1);
vec[x].push_back(np);
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(len[q] == len[p] + 1) fa[np] = q;
else{
int nq = newnode(len[p] + 1);
vec[x].push_back(nq);
fa[nq] = fa[q], fa[q] = fa[np] = nq;
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline ll solve(){
ll ans1 = 0, ans2 = 0;
for(int i = 1; i <= size; i++){
if(res[i] > len[fa[i]])
ans2 += 1ll * min(res[i], len[i]) - len[fa[i]];
ans1 += 1ll * len[i] - len[fa[i]];
}
return ans1 - ans2;
}
}
int main(){
scanf("%s", s + 1), n = strlen(s + 1);
for(int i = 1; i <= n; i++) SAM1::ins(s[i] - 'a', i);
SAM1::addedge(), SAM1::dfs(1);
read(q); int L, R;
while(q--){
scanf("%s", s + 1); int m = strlen(s + 1);
read(L), read(R);
for(int i = 1; i <= m; i++) vec[i].clear();
SAM2::Clear();
for(int i = 1; i <= m; i++) SAM2::ins(s[i] - 'a', i);
SAM1::solve(s, L, R);
printf("%lld\n", SAM2::solve());
}
return 0;
}