【算法】后缀数组(SA)

引言

关于后缀数组的原理,集训队论文中已经将的非常详细了。

但是,其中的代码对于第一接触这个东西的人来说可能就觉得有些恶心了。

所以,我这篇文章主要是讲怎么理解倍增方法的SA的构建过程。

SA的实现

基数排序

很多人理解不了这个是因为不知道基数排序

所以,在学习后缀数组之前请先动手写一个双关键字的基数排序程序:求sa[i]表示第i大的元组在数组中的下标。

然后就可以正式进入SA部分了~

SA

先上代码,然后结合代码讲~

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
void radix_sort() {
// 原理是双关键字的基数排序
// sk为第一关键字,sb为第二关键字排序
// 其中,sb相当于上一轮的sa
int i;
for (i = 0; i < m; i++) acc[i] = 0;
for (i = 0; i < n; i++) acc[sk[i]]++;
for (i = 1; i < m; i++) acc[i] += acc[i - 1];
for (i = n - 1; i >= 0; i--) sa[--acc[sk[sb[i]]]] = sb[i];
}

bool cmp(int *f, int x, int y, int w) {
// 比较字符串的大小,可以直接利用sk数组比较
return f[x] == f[y] && f[x + w] == f[y + w];
}

void suffix_array(int *a) {
for (int i = 0; i < n; i++) {
sk[i] = a[i];
// 为了一般性,最初设置一个基于位置的第二关键字
sb[i] = i;
}
radix_sort();
// 倍增
// p=n时说明已经能够将n个后缀全部区分了
for (int w = 1, p = 1, i; p < n; w <<= 1, m = p) {
// 基数排序,求新的sa
for (p = 0, i = n - w; i < n; i++)
sb[p++] = i;
for (i = 0; i < n; i++)
if (sa[i] >= w) sb[p++] = sa[i] - w;
radix_sort();
// sb存放sk的一份拷贝,因为之后要利用当前sk更新之后的sk
for (i = 0; i < n; i++) sb[i] = sk[i];
// 更新sk
p = 0;
sk[sa[0]] = p++; // sa[0]永远都是n-1,即最后补上的0;而sk[sa[0]]即sk[n-1]=0
for (i = 1; i < n; i++)
// 此处的sb实际意义相当于sk
sk[sa[i]] = cmp(sb, sa[i], sa[i - 1], w) ? p - 1 : p++;
}
}

这个我写的SA的代码,虽然比论文中的长很多,但是更好理解,时间效率也和论文中一样~

大部分代码可以看注释去理解,我只强调几个非常重要的部分:

  • sk的意义类似于rk(rank),但是不同的是,sk中使用的是对应元素的值作为排名,所以经常会出现sk中两个元素(甚至多个元素)值相同的情况;而根据rk的定义,rk是和sa互为逆运算,所以,必然是一对一映射
  • sb的意义相当于sa,sb中任意两个元素不相同
  • 因为,我按照论文中的方法,在字符串后面添加了一个0,这样会导致sa[0]永远都是n-1,sk[sa[0]]即sk[n-1]=0
  • 双关键字的基数排序理论上需要进行两次排序过程,但是实际上,代码中只进行了一次。之所以只进行了一次,是因为第二关键字已经排好顺序了,按照i从小到大的顺序访问sb[i](即上一轮的sk[i]),即可满足第二关键字由小到大排列的要求
  • p<n为结束条件是因为,p代表sk中不同数字的个数,所以,当p=n时,相当于没有并列的情况出现,即n个后缀已经全部区分开了~

应用

上面求sa和rk还无法体现SA的强大之处。因为,SA还有一个很强大的辅助数组:height。

关于hight数组,可以参考上面提到的论文,这部分无论是原理还是代码都还是比较好理解的。

应用方面主要也是关于各种各样的匹配问题的,还是参看论文吧~

这里给出poj1743中应用SA的代码:

因为我是用Jetbrain家的CLion写的,为了方便,就只能用名字空间了,对于单个代码其实没什么用,所以请忽略啦…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>

namespace POJ1743 {
using namespace std;

# define N 20020

// sk中的两个元素可以相同,但是sa,sb中的不可以相同
// 所以,实际上sk和sa并不是互逆的
int sk[N], rk[N], sa[N], sb[N], ht[N];
int acc[N];

int n, m, nn;
int a[N];
int b[N];

void radix_sort() {
// 原理是双关键字的基数排序
// sk为第一关键字,sb为第二关键字排序
// 其中,sb相当于上一轮的sa
int i;
for (i = 0; i < m; i++) acc[i] = 0;
for (i = 0; i < n; i++) acc[sk[i]]++;
for (i = 1; i < m; i++) acc[i] += acc[i - 1];
for (i = n - 1; i >= 0; i--) sa[--acc[sk[sb[i]]]] = sb[i];
}

bool cmp(int *f, int x, int y, int w) {
// 比较字符串的大小,可以直接利用sk数组比较
return f[x] == f[y] && f[x + w] == f[y + w];
}

void suffix_array(int *a) {
for (int i = 0; i < n; i++) {
sk[i] = a[i];
// 为了一般性,最初设置一个基于位置的第二关键字
sb[i] = i;
}
radix_sort();
// 倍增
// p=n时说明已经能够将n个后缀全部区分了
for (int w = 1, p = 1, i; p < n; w <<= 1, m = p) {
// 基数排序,求新的sa
for (p = 0, i = n - w; i < n; i++)
sb[p++] = i;
for (i = 0; i < n; i++)
if (sa[i] >= w) sb[p++] = sa[i] - w;
radix_sort();
// sb存放sk的一份拷贝,因为之后要利用当前sk更新之后的sk
// 而且,在下一轮循环中,sb利用sk的值可以直接排序
for (i = 0; i < n; i++) sb[i] = sk[i];
// 更新sk
p = 0;
sk[sa[0]] = p++; // sa[0]永远都是n-1,即最后补上的0;而sk[sa[0]]即sk[n-1]=0
for (i = 1; i < n; i++)
// 此处的sb实际意义相当于sk
sk[sa[i]] = cmp(sb, sa[i], sa[i - 1], w) ? p - 1 : p++;
}
}

void get_ht() {
int i, j, k = 0;
// sk和sa是互逆的(rk[sa[0]]=0)
for (i = 1; i <= nn; i++) rk[sa[i]] = i;
// h[i] >= h[i-1]-1,其中h[i]=ht[rk[i]]
for (i = 0; i < nn; ht[rk[i++]] = k)
for (k ? k-- : 0, j = sa[rk[i] - 1]; b[i + k] == b[j + k]; k++);
}

bool check(int x) {
int mx = sa[1], mn = sa[1];
for (int i = 2; i < nn; i++) {
if (ht[i] < x) mx = mn = sa[i];
else {
mx = max(mx, sa[i]);
mn = min(mn, sa[i]);
if (mx - mn >= x) return true;
}
}
return false;
}

void solve() {
while (scanf("%d", &nn), nn) {
for (int i = 0; i < nn; i++) {
scanf("%d", &a[i]);
}
nn -= 1;
for (int i = 0; i < nn; i++) {
b[i] = a[i] - a[i + 1] + 88;
}
m = 190;
// 将字符串拓展一位,保证cmp函数中不会越界
n = nn + 1;
b[n - 1] = 0;
suffix_array(b);
get_ht();

int ans = 0;
int l = 4, r = nn / 2 + 1, mid;
while (l <= r) {
mid = (l + r) >> 1;
if (check(mid)) ans = mid, l = mid + 1;
else r = mid - 1;
}
if (ans < 4) puts("0");
else printf("%d\n", ans + 1);
}
}
}

int main() {
POJ1743::solve();
return 0;
}