Luogu P7279 光棱碎片
首先可以差分将限制转化为 \((a_{r_1}\oplus a_{r_2})+(r_1-l_1+1)\le k\):即对上界 \(k\) 取区间右端点与区间左端点减一各求一次答案,再将两者相减。
将 \(\texttt{SAM}\) 建出来后对于每个本质不同子串的 \(\operatorname{endpos}\) 考虑。设点 \(x_1,x_2\) 分别对应原序列中 \(r_1,r_2\) 在 \(\texttt{parent tree}\) 上的位置,设 \(y=\operatorname{lca}(x_1,x_2)\),那么点对 \(x_1,x_2\) 的贡献为 \(\sum\limits_{i=1}^{\operatorname{len}_y}[(a_{r_1}\oplus a_{r_2})+i\le k]\)。
注意到我们要统计所有的 \(\operatorname{endpos}\) 点对,可以考虑使用 \(\texttt{dsu on tree}\) 优化。于是你要维护一个数据结构,实现以下操作:
- 向当前容器 \(S\) 内加入一个数 \(x\) ;
- 给定 \(x\) 与 \(d\),查询 \(\sum\limits_{y\in S}\max\{0,\min\{k-(x\oplus y),d\}\}\),其中 \(k\) 为定值,\(d\) 为本次查询给定的长度上限(即 \(\operatorname{len}_y\))。
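这个查询其实还有点阴间。我们记 \(c_t=\sum\limits_{y\in S}[x\oplus y\le t],\ g_t=\sum\limits_{y\in S}[x\oplus y\le t](x\oplus y)\)(当 \(t<0\) 时 \(c_t=g_t=0\)),那么所求即为 \(A=\sum\limits_{y\in S,\ x\oplus y\le k}\min\{k-(x\oplus y),d\}=d\cdot c_{k-d}+k\cdot (c_{k}-c_{k-d})-(g_{k}-g_{k-d})\):满足 \(x\oplus y\le k-d\) 的 \(y\) 各贡献 \(d\),满足 \(k-d<x\oplus y\le k\) 的 \(y\) 各贡献 \(k-(x\oplus y)\)。于是我们只需考虑如何对任意上界 \(t\) 求出 \(c_t,g_t\)。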
建出 \(\texttt{01trie}\)。查询 \(c_k\) 是基操,不用多说;而查询 \(g_k\) 时,只需要在 \(\texttt{trie}\) 的每个结点上拆位维护每一位 \(1\) 的个数即可。
总时间复杂度为 \(\mathcal O(n\log n\log ^2V)\),空间复杂度为 \(\mathcal O(n\log ^2V)\)。由于 \(\texttt{dsu on tree}\) 和 \(\texttt{01trie}\) 的常数都很小,就过了。
参考代码
#include <bits/stdc++.h>
using namespace std;
/// Modulus used for every answer-related computation in this file.
static constexpr int mod = 998244353;

/// Returns (x + y) reduced into [0, mod), for x and y already in [0, mod).
inline int add(int x, int y) {
  const int shifted = x + (y - mod);  // lies in [-mod, mod) for reduced inputs
  return shifted < 0 ? shifted + mod : shifted;
}

/// Returns (x - y) reduced into [0, mod), for x and y already in [0, mod).
inline int sub(int x, int y) {
  const int diff = x - y;
  return diff < 0 ? diff + mod : diff;
}

/// Returns x * y % mod; the product is formed in 64 bits to avoid overflow.
inline int mul(int x, int y) {
  return static_cast<int>(static_cast<int64_t>(x) * y % mod);
}

/// In-place version of add(): x <- (x + y) mod `mod`.
inline void add_eq(int &x, int y) {
  x += y - mod;
  if (x < 0) x += mod;
}

/// In-place version of sub(): x <- (x - y) mod `mod`.
inline void sub_eq(int &x, int y) {
  x -= y;
  if (x < 0) x += mod;
}

/// In-place version of mul(): x <- x * y % mod.
inline void mul_eq(int &x, int y) {
  x = mul(x, y);
}
// Maxn: capacity of every fixed-size global array in this file.
// MaxS: alphabet size; main() maps characters with `str[i] - 'a'`.
static constexpr int Maxn = 2e5 + 5, MaxS = 26;
// n    : string length read by main().
// en   : number of edges stored in e[] so far (bumped by add_edge).
// head : head[u] is the index in e[] of the most recently added edge leaving u, 0 if none.
// dn   : running DFS timestamp assigned by sack_init().
// ans  : answer accumulated modulo `mod` by sack(), printed by main().
int n, en, head[Maxn], dn, ans;
// wl, wr : the two bounds read by main(); query() evaluates at wr and at wl - 1.
// w[i]   : value attached to position i (1-based), read by main().
int wl, wr, w[Maxn];
// Input string, stored 1-based (main() reads it into str + 1).
char str[Maxn];
/// One directed edge of the forward-star adjacency list: `to` is the target
/// vertex, `nxt` is the index in e[] of the next edge leaving the same source
/// (0 terminates the list).
struct Edge { int to, nxt; } e[Maxn];

/// Prepends the directed edge u -> v to u's adjacency list.
/// Edge indices start at 1 so that 0 can serve as the list terminator.
void add_edge(int u, int v) {
  // Standard brace initialisation instead of the C99 compound literal
  // `(Edge){...}`, which is only a compiler extension in C++.
  e[++en] = Edge{v, head[u]};
  head[u] = en;
}
// One suffix-automaton state: ch[c] is the transition on letter c (0 = absent),
// link is the suffix link (main() sets the root's link to -1), len is the
// length stored for the state by extend().
struct state { int ch[MaxS], link, len; } tr[Maxn];
// last : state created by the most recent extend() call (0 = root initially).
// sn   : index of the highest state allocated so far.
// edp[len]   : the non-clone state created when the prefix of length `len` was appended.
// iedp[state]: inverse of edp; stays 0 for the root and for cloned states,
//              which sack() uses to skip them.
int last, sn, edp[Maxn], iedp[Maxn];
/// Appends one character to the suffix automaton.
///
/// @param c  alphabet index of the character, expected in [0, MaxS).
///
/// Creates the state for the new full prefix, records it in edp/iedp, and
/// fixes suffix links, cloning a state when required. State 0 is the root and
/// its link is -1, so `p == -1` means "walked past the root".
void extend(int c) {
  int p = last;
  const int cur = last = ++sn;
  tr[cur].len = tr[p].len + 1;
  edp[tr[cur].len] = cur;
  iedp[cur] = tr[cur].len;

  // Add the missing c-transitions along the suffix-link chain.
  for (; p != -1 && !tr[p].ch[c]; p = tr[p].link) tr[p].ch[c] = cur;
  // Fell off the root: tr[cur].link keeps its zero-initialised value, i.e. the root.
  if (p == -1) return;

  const int q = tr[p].ch[c];
  if (tr[q].len == tr[p].len + 1) {
    tr[cur].link = q;
    return;
  }

  // q is too long to be the link target: split off a clone r with the right
  // length. iedp[r] is left at 0, so the clone is not treated as a prefix end.
  const int r = ++sn;
  tr[r].len = tr[p].len + 1;
  // Copy the whole transition table; sizeof avoids assuming sizeof(int) == 4.
  memcpy(tr[r].ch, tr[q].ch, sizeof(tr[q].ch));
  for (; p != -1 && tr[p].ch[c] == q; p = tr[p].link) tr[p].ch[c] = r;
  tr[r].link = tr[q].link;
  tr[q].link = tr[cur].link = r;
} // extend
/// Binary trie over LOG-bit values. Every node stores how many inserted values
/// pass through it and, per bit position, how many of those values have that
/// bit set. This allows both "count of y with (w xor y) <= r" and "sum of
/// (w xor y) over those y" to be answered in one descent.
///
/// Assumes every inserted value and every query value `w` is in [0, 2^LOG) --
/// higher bits are ignored by insert(). TODO confirm against the input limits.
namespace trie {
static constexpr int LOG = 17;

/// ch: children by bit (0 = absent); c: number of values in the subtree;
/// s[k]: number of values in the subtree whose bit k is set.
struct node {
  int ch[2], c, s[LOG];
  node() = default;  // keeps `node()` a zero-initialising value-initialisation
}
// Index 0 is never written and stays all-zero: it is the "absent child"
// sentinel. Index 1 is the root.
// NOTE(review): with 4-byte int this is Maxn * LOG * 2 nodes of 80 bytes
// (roughly 544 MB of static storage); presumably only the touched prefix is
// ever resident -- confirm against the memory limit.
tr[Maxn * LOG * 2];

/// Index of the last allocated node; resetting it to 1 (and clearing the
/// root) empties the trie in O(1).
int tn = 1;

/// Allocates a zeroed node and returns its index.
inline int newnode(void) {
  return tr[++tn] = node(), tn;
} // trie::newnode

/// Registers value `w` in the statistics of node `p`.
inline void touch(int p, int w) {
  ++tr[p].c;
  for (int k = 0; k < LOG; ++k) tr[p].s[k] += (w >> k & 1);
}

/// Inserts `w`, updating counters on the root and on every node of its path.
void insert(int w) {
  int p = 1;
  touch(p, w);
  for (int i = LOG - 1; i >= 0; --i) {
    const int dir = w >> i & 1;
    if (!tr[p].ch[dir]) tr[p].ch[dir] = newnode();
    p = tr[p].ch[dir];
    touch(p, w);
  }
} // trie::insert

/// Adds to (c, s) the count of values below node `p` and the sum, modulo
/// `mod`, of (w xor y) over those values. A no-op for the sentinel p == 0.
inline void collect(int p, int w, int &c, int &s) {
  c += tr[p].c;
  for (int k = 0; k < LOG; ++k) {
    // Bit k of (w xor y) is set for the y whose bit k differs from w's.
    const int cs = (w >> k & 1) ? tr[p].c - tr[p].s[k] : tr[p].s[k];
    add_eq(s, (static_cast<int64_t>(cs) << k) % mod);
  }
}

/// Returns {count, sum mod `mod`} over stored y with (w xor y) <= r.
///
/// @param w  query value.
/// @param r  inclusive upper bound on the xor; r < 0 matches nothing.
pair<int, int> ask(int w, int r) {
  if (r < 0) return {0, 0};
  // Only the low LOG bits of r are examined below, so a bound of 2^LOG or more
  // used to be read as a much smaller one. Every stored xor is below 2^LOG,
  // so clamping makes such bounds match everything, as they should.
  if (r >= (1 << LOG)) r = (1 << LOG) - 1;

  int p = 1, c = 0, s = 0;
  for (int i = LOG - 1; i >= 0 && p; --i) {
    const int wbit = w >> i & 1;
    if (r >> i & 1) {
      // xor bit 0 < r's bit 1: the whole "same bit as w" subtree qualifies.
      collect(tr[p].ch[wbit], w, c, s);
      p = tr[p].ch[wbit ^ 1];  // stay tight on the bound: xor bit 1
    } else {
      p = tr[p].ch[wbit];      // stay tight on the bound: xor bit 0
    }
  }
  // Whatever remains at p has (w xor y) == r exactly (nothing if p == 0).
  collect(p, w, c, s);
  return {c, s};
} // trie::ask
} // namespace trie
/// Returns, modulo `mod`, the sum over values y currently in the trie of
/// max(0, min(r - (w xor y), len)).
///
/// Values with xor <= r - len each contribute `len`; values with
/// r - len < xor <= r each contribute r - xor.
inline int ask(int w, int r, int len) {
  const pair<int, int> capped = trie::ask(w, r - len);  // {count, xor-sum} for xor <= r - len
  const pair<int, int> within = trie::ask(w, r);        // {count, xor-sum} for xor <= r

  const int full_part = mul(capped.first, len);
  const int bound_part = mul(within.first - capped.first, r);
  const int xor_part = sub(within.second, capped.second);
  return add(full_part, sub(bound_part, xor_part));
} // ask
/// Contribution of value `w` against the current trie contents for the bound
/// range [wl, wr]: the result for bound wr minus the result for bound wl - 1,
/// modulo `mod`.
inline int query(int w, int len) {
  const int up_to_high = ask(w, wr, len);
  const int below_low = ask(w, wl - 1, len);
  return sub(up_to_high, below_low);
}
// Per-vertex data filled by sack_init():
// sz   : subtree size;  hson : child with the largest subtree (-1 for a leaf);
// dep  : depth passed down from the root;  dfn : preorder timestamp (1-based);
// idfn : inverse of dfn (timestamp -> vertex).
int sz[Maxn], hson[Maxn], dep[Maxn], dfn[Maxn], idfn[Maxn];
/// Preorder DFS that fills dep/sz/hson/dfn/idfn for the subtree rooted at `u`.
///
/// @param u      vertex to process.
/// @param depth  depth to record for `u`.
void sack_init(int u, int depth) {
  dep[u] = depth;
  sz[u] = 1;
  hson[u] = -1;
  dfn[u] = ++dn;
  idfn[dfn[u]] = u;

  for (int id = head[u]; id != 0; id = e[id].nxt) {
    const int child = e[id].to;
    sack_init(child, depth + 1);
    sz[u] += sz[child];
    // Keep the first child seen with the largest subtree as the heavy child.
    if (hson[u] == -1 || sz[hson[u]] < sz[child]) hson[u] = child;
  }
} // sack_init
void sack(int u, bool keep) {
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != hson[u]) sack(v, false);
if (hson[u] != -1) sack(hson[u], true);
if (iedp[u] != 0) {
add_eq(ans, query(w[iedp[u]], tr[u].len));
trie::insert(w[iedp[u]]);
}
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != hson[u]) {
for (int i = dfn[v], x; i < dfn[v] + sz[v]; ++i)
if (iedp[x = idfn[i]] != 0) add_eq(ans, query(w[iedp[x]], tr[u].len));
for (int i = dfn[v], x; i < dfn[v] + sz[v]; ++i)
if (iedp[x = idfn[i]] != 0) trie::insert(w[iedp[x]]);
}
if (!keep) trie::tr[trie::tn = 1] = trie::node();
} // sack
/// Reads n, the string, the n values and the two bounds; builds the suffix
/// automaton and its parent tree; runs the small-to-large traversal and
/// prints the answer modulo `mod`.
///
/// Returns EXIT_FAILURE if the input cannot be parsed instead of continuing
/// with unset data.
int main(void) {
  // NOTE(review): %s is unbounded; a token longer than the space left in
  // str[] after the leading offset overflows the buffer.
  if (scanf("%d%s", &n, str + 1) != 2) return EXIT_FAILURE;

  // State 0 is the root; link -1 marks "no parent" for extend().
  last = sn = 0;
  tr[0].link = -1;
  for (int i = 1; i <= n; ++i) extend(str[i] - 'a');
  // Parent tree: an edge from each state's suffix link to the state.
  for (int i = 1; i <= sn; ++i) add_edge(tr[i].link, i);

  for (int i = 1; i <= n; ++i)
    if (scanf("%d", &w[i]) != 1) return EXIT_FAILURE;
  if (scanf("%d%d", &wl, &wr) != 2) return EXIT_FAILURE;

  dn = 0;
  sack_init(0, 0);
  sack(0, false);
  printf("%d\n", ans);
  return EXIT_SUCCESS;
} // main