UOJ-581 NOIP2020 字符串匹配

Description

给定小写字母组成的字符串 \(S\)。定义 \(AB\) 表示字符串 \(A, B\) 拼接,\(A^n=A^{n-1}A\) 表示 \(A\) 复制 \(n\) 遍。求三元组 \((A, B, C)\) 的个数,满足 \(S\) 可以写成 \((AB)^i C\) 的形式。共 \(T\) 组数据。

Constraints

\(1\le |S| \le 2^{20}, 1\le T\le 5\)​。

Solution

首先是一个比较大众也比较好想的做法。记 \(pre(i)\)​ 表示前缀 \(i\)​ 中出现奇数次的字符个数,同理对后缀定义 \(suf(i)\)​。字符集用 \(\Sigma\) 表示。

考虑枚举 \(AB\)​​ 的长 \(x\)​​,那么前缀 \(S[1:x]\)​​ 就是 \(AB\)​​。考虑 \(S\)​​ 将会由 \(AB\)​​ 循环若干次构成,剩下的就是 \(C\)​​,那么考虑找到一个最大的循环次数 \(k\)​​,哈希即可。找到 \(k\)​​ 之后,就能对于每个循环次数 \(i\in [1, k]\)​​,求出 \(pre\)​ 上 \([1, ix)\)​ 中有多少个 \(\le suf(ix+1)\)​ 就是对答案的贡献。考虑到值域是 \([0, 26]\)​,树状数组维护单次操作是 \(O(\log |\Sigma|)\)​。直接做是 \(O(T\sum_{i=1}^n\tfrac n i) = O(Tn\log n\log |\Sigma|)\)​ 的,想要通过比较困难。

优化其实并不难,考虑一下两个要点:

  • \(k\)​ 的合法性是单调的;
  • 对于一个 \(k\),不需要枚举 \(i\),奇数偶数分开算,同为奇数或偶数贡献是一样的。

第二个比较简单。第一个我们可以考虑二分或者倍增找到 \(k\)。这样的话复杂度大概是 \(O(\sum_{i=1}^n\log (\tfrac n i))\approx O(n)\)。参考 这里

二分我不太会保证复杂度,这里介绍一种倍增方法。

设 \(X=S[1:x]\),其哈希值为 \(H(X)\)。那么我们可以得到复制 \(t\) 倍的串的哈希值:

\[H(X^t)=\sum_{i=0}^{k-1}b^{ix} H(X)=\frac{b^{tx}-1}{b^x-1}H(X) \]

其中 \(b\)​ 为哈希的基数。计算一个哈希值,如果使用快速幂的话,需要 \(O(\log tx)\)​ 的时间。不过如果是倍增的话,我们只需要计算 \(\lfloor\log_2\tfrac n x\rfloor\)​ 个 \(b^{tx}\)​ 值即可,每一项等于前一项的平方。逆元直接算是 \(O(\log \bmod)\)​ 的,尽管每个 \(x\)​ 都只算一次也是不可接受的。那么只好用一个 离线求逆元的 trick,预处理所有 \(x\)​ 的 \((b^x-1)^{-1}\)​。这样复杂度就只有 \(O(Tn\log |\Sigma|)\)​ 了。不会二分是因为倍增可以预处理 \(\lfloor\log_2\tfrac n x\rfloor\)​ 个 \(b^{tx}\) 而二分我就不知道了。

Code

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
const int N = 1 << 20 | 5;

typedef unsigned long long ull;
const ull base = 19260817;
const ull mod = 1e9 + 7;
ull pw[N];

int n;
long long ans;
char s[N];
int pre[N], suf[N];
ull hs[N];

inline ull fastpow(ull a, ull b) {
  ull r = 1llu;
  for (; b; b >>= 1, (a *= a) %= mod)
    if (b & 1) (r *= a) %= mod;
  return r;
}

ull buf[N], inv[N];
namespace inversion {
  ull pre[N], suf[N];
  void process(int n) {
    memset(pre, 0, sizeof(pre));
    memset(suf, 0, sizeof(suf));
    memset(inv, 0, sizeof(inv));
    pre[0] = suf[n + 1] = 1llu;
    for (int i = 1; i <= n; i++)
      pre[i] = pre[i - 1] * buf[i] % mod;
    for (int i = n; i >= 1; i--)
      suf[i] = suf[i + 1] * buf[i] % mod;
    ull all = fastpow(pre[n], mod - 2);
    for (int i = 1; i <= n; i++)
      inv[i] = pre[i - 1] * suf[i + 1] % mod * all % mod;
  }
}

struct bit {
  int t[28];
  inline int get(int x) {
    int r = 0;
    for (++x; x; x -= x & -x) r += t[x];
    return r;
  }
  inline void add(int x) {
    for (++x; x <= 27; x += x & -x) ++t[x];
  }
  inline void reset() {
    memset(t, 0, sizeof(t));
  }
} tr;

signed main() {
  pw[0] = 1llu;
  for (int i = 1; i < N; i++)
    pw[i] = pw[i - 1] * base % mod;

  int T;
  scanf("%d", &T);
  while (T--) {
    scanf("%s", s + 1);
    n = strlen(s + 1);
    ans = 0;
    tr.reset();

    memset(hs, 0, sizeof(hs));
    memset(pre, 0, sizeof(pre));
    memset(suf, 0, sizeof(suf));
    memset(buf, 0, sizeof(buf));

    for (int i = 1; i <= n; i++)
      hs[i] = (hs[i - 1] * base + s[i]) % mod;

    pre[0] = suf[n + 1] = 0;
    for (int i = 1, v = 0; i <= n; i++) {
      int nv = v ^ (1 << (s[i] - 'a'));
      if (nv > v) pre[i] = pre[i - 1] + 1;
      else pre[i] = pre[i - 1] - 1;
      v = nv;
    }
    for (int i = n, v = 0; i >= 1; i--) {
      int nv = v ^ (1 << (s[i] - 'a'));
      if (nv > v) suf[i] = suf[i + 1] + 1;
      else suf[i] = suf[i + 1] - 1;
      v = nv;
    }
    
    for (int x = 2; x < n; x++)
      buf[x - 1] = pw[x] - 1;
    inversion::process(n - 2);
    tr.add(pre[1]);
    for (int x = 2; x < n; x++) {
      int k = 0, maxb = log2(n / x);
      ull cst = inv[x - 1] * hs[x] % mod;
      ull fix = 0;

      ull tpw[maxb + 1];
      tpw[0] = pw[x];
      for (int j = 1; j <= maxb; j++)
        tpw[j] = tpw[j - 1] * tpw[j - 1] % mod;

      for (int j = maxb; j >= 0; j--) {
        ull cur = ((tpw[j] - 1) * cst % mod * pw[k * x] % mod + fix) % mod;
        if (cur == hs[x * (k + (1 << j))])
          k += (1 << j), fix = cur;
      }
      
      if (x * k == n) --k;
      ans += tr.get(suf[x + 1]) * ((k + 1) / 2);
      if (k > 1) ans += tr.get(suf[x * 2 + 1]) * (k / 2);
      tr.add(pre[x]);
    }

    printf("%lld\n", ans);
  }
  return 0;
}
上一篇:P3706-[SDOI2017]硬币游戏【高斯消元,字符串hash】


下一篇:题解 CF613E Puzzle Lover