洛谷P2336 喵星球上的点名

洛谷P2336 喵星球上的点名

洛谷P2336 喵星球上的点名

解:SAM + 线段树合并 + DFS序。

姓和名之间插入特殊字符,转化为下题:

给定串集合S,T,问S中每个串包含了T中的几个串?T中每个串被多少个S中的串包含?

解:对S建广义SAM,并线段树合并维护每个节点有多少串。

T中每个串在S的sam上跑,如果没能跑完就被包含0次。否则答案就是到达的节点上的串数。第二问解决。

标记T中每个串最后到达的节点。S中每个串跑S的sam会得到若干个点。统计这些点到根路径的并集上的标记个数即可。

按DFS序排序,加上每个节点到根路径的贡献,减去相邻节点lca到根路径的贡献。第一问解决。

 #include <bits/stdc++.h>

 const int N = , M = ;

 struct Edge {
int nex, v;
}edge[N]; int tp; std::map<int, int> tr[N];
int fail[N], len[N], tot = , e[N], n, m, siz[N], ed[N],
stk[N], top, num2, pos2[N], ST[N << ][], pw[N << ], d[N];
int ls[M], rs[M], num, sum[M], rt[N];
std::vector<int> str[N]; inline void add(int x, int y) {
tp++;
edge[tp].v = y;
edge[tp].nex = e[x];
e[x] = tp;
return;
} inline bool cmp(const int &a, const int &b) {
return pos2[a] < pos2[b];
} void insert(int p, int l, int r, int &o) {
if(!o) o = ++num;
if(l == r) {
sum[o] = ;
return;
}
int mid = (l + r) >> ;
if(p <= mid) insert(p, l, mid, ls[o]);
else insert(p, mid + , r, rs[o]);
sum[o] = sum[ls[o]] + sum[rs[o]];
return;
} int merge(int x, int y) {
if(!x || !y) return x | y;
int o = ++num;
ls[o] = merge(ls[x], ls[y]);
rs[o] = merge(rs[x], rs[y]);
if(!ls[o] && !rs[o]) sum[o] = ;
else sum[o] = sum[ls[o]] + sum[rs[o]];
return o;
} int ask(int L, int R, int l, int r, int o) {
if(!o) return ;
if(L <= l && r <= R) return sum[o];
int mid = (l + r) >> , ans = ;
if(L <= mid) ans += ask(L, R, l, mid, ls[o]);
if(mid < R) ans += ask(L, R, mid + , r, rs[o]);
return ans;
} void DFS_1(int x) {
pos2[x] = ++num2;
ST[num2][] = x;
for(int i = e[x]; i; i = edge[i].nex) {
int y = edge[i].v;
d[y] = d[x] + ;
DFS_1(y);
ST[++num2][] = x;
rt[x] = merge(rt[x], rt[y]);
}
return;
} inline void prework() {
for(int i = ; i <= num2; i++) pw[i] = pw[i >> ] + ;
for(int j = ; j <= pw[num2]; j++) {
for(int i = ; i + (j << ) - <= num2; i++) {
if(d[ST[i][j - ]] < d[ST[i + ( << (j - ))][j - ]])
ST[i][j] = ST[i][j - ];
else
ST[i][j] = ST[i + ( << (j - ))][j - ];
}
}
return;
} inline int lca(int x, int y) {
x = pos2[x];
y = pos2[y];
if(x > y) std::swap(x, y);
int t = pw[y - x + ];
if(d[ST[x][t]] < d[ST[y - ( << t) + ][t]])
return ST[x][t];
else
return ST[y - ( << t) + ][t];
} void DFS_2(int x) {
siz[x] += ed[x];
for(int i = e[x]; i; i = edge[i].nex) {
int y = edge[i].v;
siz[y] = siz[x];
DFS_2(y);
}
return;
} inline int split(int p, int f) {
int Q = tr[p][f], nQ = ++tot;
len[nQ] = len[p] + ;
fail[nQ] = fail[Q];
fail[Q] = nQ;
//memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
tr[nQ] = tr[Q];
while(tr[p][f] == Q) {
tr[p][f] = nQ;
p = fail[p];
}
return nQ;
} inline int insert(int p, int f, int id) {
int np;
if(tr[p].count(f)) {
int Q = tr[p][f];
if(len[Q] == len[p] + ) {
np = Q;
}
else {
np = split(p, f);
}
insert(id, , n, rt[np]);
return np;
}
np = ++tot;
len[np] = len[p] + ;
while(p && !tr[p].count(f)) {
tr[p][f] = np;
p = fail[p];
}
if(!p) {
fail[np] = ;
}
else {
int Q = tr[p][f];
if(len[Q] == len[p] + ) {
fail[np] = Q;
}
else {
fail[np] = split(p, f);
}
}
insert(id, , n, rt[np]);
return np;
} void out(int l, int r, int o) {
if(!o) return;
if(l == r) {
printf("%d ", r);
return;
}
int mid = (l + r) >> ;
out(l, mid, ls[o]);
out(mid + , r, rs[o]);
return;
} inline void clear() {
for(int i = ; i <= tot; i++) {
e[i] = len[i] = fail[i] = rt[i] = ;
tr[i].clear();
}
for(int i = ; i <= num; i++) {
ls[i] = rs[i] = sum[i] = ;
}
tp = num = ;
tot = ;
return;
} int main() {
scanf("%d%d", &n, &m);
for(int i = ; i <= n; i++) {
int k, x, p = ;
scanf("%d", &k);
for(int j = ; j <= k; j++) {
scanf("%d", &x);
str[i].push_back(x);
p = insert(p, x, i);
}
str[i].push_back(-);
p = insert(p, -, i);
scanf("%d", &k);
for(int j = ; j <= k; j++) {
scanf("%d", &x);
str[i].push_back(x);
p = insert(p, x, i);
}
}
/// build
for(int i = ; i <= tot; i++) {
//printf("add %d %d \n", fail[i], i);
add(fail[i], i);
}
DFS_1(); for(int i = ; i <= m; i++) {
int k, x, p = , fd = , ans = ;
scanf("%d", &k);
for(int j = ; j <= k; j++) {
scanf("%d", &x);
if(!tr[p].count(x)) fd = ;
else p = tr[p][x];
}
if(!fd) {
ans = sum[rt[p]];
ed[p]++;
}
printf("%d\n", ans);
} DFS_2();
prework(); for(int i = ; i <= n; i++) {
int p = ; top = ;
for(int j = ; j < (int)str[i].size(); j++) {
int x = str[i][j];
p = tr[p][x];
stk[++top] = p;
}
std::sort(stk + , stk + top + ,cmp);
top = std::unique(stk + , stk + top + ) - stk - ;
int ans = ;
for(int j = ; j <= top; j++) {
ans += siz[stk[j]];
if(j < top) ans -= siz[lca(stk[j], stk[j + ])];
}
printf("%d ", ans);
} return ;
}

AC代码

AC自动机解法

上一篇:(实用篇)浅谈PHP拦截器之__set()与__get()的理解与使用方法


下一篇:poj3249