后缀数组简介

前置知识

后缀数组模板

什么是后缀数组

后缀数组的核心是两个数组:\(sa,rk\)。

将字符串 \(s\) 的所有后缀从小到大排序,\(sa[i]\) 表示排名第 \(i\) 的后缀的起始字符在原字符串中的下标(从 \(1\) 开始编号),\(rk[i]\) 表示起始位置为 \(i\) 的后缀的排名。容易发现,在已知 \(sa,rk\) 中的一个数组时,可以得到另一个数组,即 \(rk[sa[i]]=i;sa[rk[j]]=j\)。

后缀数组的求法

后缀数组最常用、实用性最广的求法是倍增。我们用 \(sa_w[i],rk_w[i]\) 表示 \(\{s[x...\min(n,x+w-1)]|x=[1,n]\}\) 中,排名第 \(i\) 的子串的起始字符在原字符串中的下标,和起始位置为 \(i\) 的子串的排名。

类似倍增,在已知 \(rk_{w}[i],rk_{w}[i+w]\) 的情况下,要在求出 \(sa_{2w}\) 的这一轮中将这些字符串排序,只需将所有“\(s[i...i+2w-1]\)”串按 \(rk_w[i]\) 为第一关键字,\(rk_w[i+w]\) 为第二关键字排序即可。(对于 \(x>w\) 的 \(rk[x]\) 可以视作无穷小)

接下来求 \(rk_{2w}\),我们从 \(1\) 到 \(n\) 枚举排名 \(i\),如果排名 \(i\) 的 len=\(2w\) 子串和排名 \(i-1\) 的全等(即 \(rk_w[sa_{2w}[i]]=rk_w[sa_{2w}[i]+w]\)),那么他们应该拥有相同的排名,否则排名++。

因此核心在于双关键字排序的实现方式,用朴素 sort 是 \(O(n\log n)\),用计数排序+基数排序是 \(O(n)\) 的。

sort 代码:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

const int N = 1000010;

char s[N];
int n, w, sa[N], rk[N << 1], oldrk[N << 1];
// 为了防止访问 rk[i+w] 导致数组越界,开两倍数组。
// 当然也可以在访问前判断是否越界,但直接开两倍数组方便一些。

int main() {
  int i, p;

  scanf("%s", s + 1);
  n = strlen(s + 1);
  for (i = 1; i <= n; ++i) sa[i] = i, rk[i] = s[i];

  for (w = 1; w < n; w <<= 1) {
    sort(sa + 1, sa + n + 1, [](int x, int y) {
      return rk[x] == rk[y] ? rk[x + w] < rk[y + w] : rk[x] < rk[y];
    });  // 这里用到了 lambda
    memcpy(oldrk, rk, sizeof(rk));
    // 由于计算 rk 的时候原来的 rk 会被覆盖,要先复制一份
    for (p = 0, i = 1; i <= n; ++i) {
      if (oldrk[sa[i]] == oldrk[sa[i - 1]] &&
          oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) {
        rk[sa[i]] = p;
      } else {
        rk[sa[i]] = ++p;
      }  // 若两个子串相同,它们对应的 rk 也需要相同,所以要去重
    }
  }

  for (i = 1; i <= n; ++i) printf("%d ", sa[i]);

  return 0;
}
计数排序+基数排序代码:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

const int N = 1000010;

char s[N];
int n, sa[N], rk[N << 1], oldrk[N << 1], id[N], cnt[N];

int main() {
  int i, m, p, w;

  scanf("%s", s + 1);
  n = strlen(s + 1);
  m = max(n, 125);
  for (i = 1; i <= n; ++i) sa[i] = i, rk[i] = s[i]; //这里为了方便直接把rk赋成ASCII,虽然排名可能不连续,但后面总是要排序的


  for (w = 1; w < n; w <<= 1) {
    memset(cnt, 0, sizeof(cnt));
    for (i = 1; i <= n; ++i) id[i] = sa[i];
    for (i = 1; i <= n; ++i) ++cnt[rk[id[i] + w]];
    for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
    for (i = n; i >= 1; --i) sa[cnt[rk[id[i] + w]]--] = id[i];
    memset(cnt, 0, sizeof(cnt));
    for (i = 1; i <= n; ++i) id[i] = sa[i];
    for (i = 1; i <= n; ++i) ++cnt[rk[id[i]]];
    for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
    for (i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i];
    memcpy(oldrk, rk, sizeof(rk));
    for (p = 0, i = 1; i <= n; ++i) {
      if (oldrk[sa[i]] == oldrk[sa[i - 1]] &&
          oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) {
        rk[sa[i]] = p;
      } else {
        rk[sa[i]] = ++p;
      }
    }
  }

  for (i = 1; i <= n; ++i) printf("%d ", sa[i]);

  return 0;
}

常数优化

上述代码会 TLE,因为常数较大。

方式一 减少不连续的内存访问次数if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) 中 \(sa[i],sa[i-1]\) 都访问了 2 次, for (i = 1; i <= n; ++i) ++cnt[rk[id[i]]]; for (i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i]; 中 \(rk[id[i]]\) 访问了两次,如果能将 \(rk[id[1\sim n]]\) 存下来、让 \(sa[i],sa[i-1]\) 只访问一次将减小很多常数。该优化方式与计算机内部数据存储方式有关。具体实现见下方完整代码。

方式二 当已经求出最终的 \(sa\) 数组时就跳出 \(w\) 循环:在 \(w\) 循环末尾加入 if(p==n){for(i=1;i<=n;i++)sa[rk[i]]=i;break;}。原因是我们最后的答案一定是排名互不相同,现在虽然还是“子串”不是要求的后缀,但是既然他们目前的排名已经互不相同了那么把它后面再加一些字符已经不会影响他的排名了。

方式三 优化计数排序值域:在前一轮的求 \(rk\) 中我们已经求出一个 \(p\),所以这一轮的值域 \(m=p\)。

【模板】LOJ#111 后缀排序,完整代码:

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5;
char s[N];
int n,m,sa[N<<1],rk[N<<1],oldrk[N<<1],rkid[N<<1],cnt[N],id[N];
bool cmp(int x,int y,int w){return oldrk[x]==oldrk[y]&&oldrk[x+w]==oldrk[y+w];}
int main()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	m=max(125,n);
	for(int i=1;i<=n;i++)sa[i]=i,rk[i]=s[i];
	for(int w=1;w<n;w<<=1){
		memset(cnt,0,sizeof(cnt));
		for(int i=1;i<=n;i++)id[i]=sa[i];
		for(int i=1;i<=n;i++)cnt[rk[id[i]+w]]++;
		for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--)sa[cnt[rk[id[i]+w]]--]=id[i];
		memset(cnt,0,sizeof(cnt));
		for(int i=1;i<=n;i++)id[i]=sa[i],rkid[i]=rk[id[i]];
		for(int i=1;i<=n;i++)cnt[rkid[i]]++;
		for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
		for(int i=n;i>=1;i--)sa[cnt[rkid[i]]--]=id[i];
		memcpy(oldrk,rk,sizeof(rk));
		int p=0;
		for(int i=1;i<=n;i++)rk[sa[i]]=cmp(sa[i],sa[i-1],w)?p:++p;
		if(p==n){for(int i=1;i<=n;i++)sa[rk[i]]=i;break;}
	}
	for(int i=1;i<=n;i++)printf("%d ",sa[i]);
}

练习题

参考资料

上一篇:[转]opencv二值化的cv2.threshold函数


下一篇:UVa 1630 - Folding (区间dp)