「AHOI 2013」差异【后缀数组】

Time Limit: 40 Sec Memory Limit: 512 MB

Description

给定一个长度为 $n$ 的字符串 $S$,令 $T_i$ 表示它从第 $i$ 个字符开始的后缀。求

$$\sum_{1 \leqslant i < j \leqslant n} \Big(\text{len}(T_i) + \text{len}(T_j) - 2 \times \text{lcp}(T_i, T_j)\Big)$$

其中,$\text{len}(a)$ 表示字符串 $a$ 的长度,$\text{lcp}(a, b)$ 表示字符串 $a$ 和字符串 $b$ 的最长公共前缀的长度。

Input

一行,一个字符串 $S$。

Output

一行,一个整数,表示所求值。

Sample Input

cacao

Sample Output

54

Constraints

对于 100% 的数据,保证 $2 \leqslant n \leqslant 500000$。

Solution

对于要求解的表达式,先只考虑前两项,易知

$$\sum_{1 \leqslant i < j \leqslant n} \Big(\text{len}(T_i) + \text{len}(T_j)\Big) = \sum_{i=2}^n \Big(\frac{i \cdot (i-1)}{2} + i \cdot (i-1) \Big)$$

剩下的一项是 $-2 \times \sum_{1 \leqslant i < j \leqslant n} \text{lcp}(T_i, T_j)$。两个后缀的 lcp 等于它们在后缀数组中排名之间的 $\text{height}$ 的最小值,因此这一项相当于求 $\text{height}$ 数组中每个区间的最小值之和(再乘以 $-2$)。

考虑 $\text{height}[i]$ 对答案的贡献,可以用单调栈预处理出 $\text{height}[i]$ 向左和向右分别能延伸到哪里(在这一段中 $\text{height}[i]$ 是最小值),它的贡献就是 $2 \times \text{height}[i]$ 乘以可选左端点数与可选右端点数之积,再从答案中减去即可。

Code

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

// Capacity shared by every working array below; all data is stored 1-indexed.
const int maxn = 500010;

char s[maxn];    // input string, characters live in s[1..n]
int n;           // length of the input string
int sa[maxn];    // sa[r]: start position of the suffix with rank r (filled by build_sa)
int rk[maxn];    // rk[i]: rank of the suffix starting at position i
int ht[maxn];    // ht[r]: common-prefix length of the suffixes ranked r - 1 and r
int f[maxn];     // f[i]: right end of the index range attributed to ht[i] in main
int g[maxn];     // g[i]: left end of that range
int num[maxn];   // sorted, deduplicated copy of the characters (build_sa scratch)
int a[maxn];     // a[i]: 1-based index of s[i] among the distinct characters
int fir[maxn];   // first sort key of each suffix in a doubling round
int sec[maxn];   // second sort key of each suffix (0 when the suffix is too short)
int buc[maxn];   // counting-sort buckets
int tmp[maxn];   // suffix positions ordered by their second key
ll res = 0;      // running answer printed by main
stack<int> st;   // index stack reused by the two scans in main

/**
 * Builds the suffix array of s[1..n] by prefix doubling, then fills the
 * height array.
 *
 * Reads the globals s and n. On return:
 *   sa[r] is the start position of the suffix with rank r (1 <= r <= n),
 *   rk[i] is the rank of the suffix starting at position i,
 *   ht[r] is the common-prefix length of the suffixes ranked r - 1 and r
 *         (ht[1] is 0, the rank-1 suffix has no predecessor).
 * num, a, fir, sec, buc and tmp are used as scratch space.
 */
void build_sa() {
    // Compress the characters to 1..m, m being the number of distinct characters.
    copy(s + 1, s + n + 1, num + 1);
    sort(num + 1, num + n + 1);
    int *num_end = unique(num + 1, num + n + 1);
    for (int i = 1; i <= n; i++) a[i] = lower_bound(num + 1, num_end, s[i]) - num;

    // Every bucket index used in this function lies in [0, n], so only that
    // prefix of buc has to be cleared before each counting pass.
    fill(buc, buc + n + 1, 0);
    for (int i = 1; i <= n; i++) buc[a[i]]++;
    for (int i = 1; i <= n; i++) buc[i] += buc[i - 1];
    // Initial rank: 1 + number of characters strictly smaller than the first one.
    for (int i = 1; i <= n; i++) rk[i] = buc[a[i] - 1] + 1;

    for (int k = 1; k <= n; k <<= 1) {
        // Key of suffix i: (rank of its first k characters, rank of the next k
        // characters, or 0 when the suffix ends before position i + k).
        for (int i = 1; i <= n; i++) {
            fir[i] = rk[i];
            sec[i] = i + k > n ? 0 : rk[i + k];
        }

        // Pass 1: counting sort on the second key; tmp is filled from the back,
        // so it lists the suffixes in descending order of sec.
        fill(buc, buc + n + 1, 0);
        for (int i = 1; i <= n; i++) buc[sec[i]]++;
        for (int i = 1; i <= n; i++) buc[i] += buc[i - 1];
        for (int i = 1; i <= n; i++) tmp[n - --buc[sec[i]]] = i;

        // Pass 2: counting sort on the first key. Visiting tmp from the largest
        // second key downwards and filling each bucket from its end keeps
        // suffixes with equal first key ordered by second key.
        fill(buc, buc + n + 1, 0);
        for (int i = 1; i <= n; i++) buc[fir[i]]++;
        for (int i = 1; i <= n; i++) buc[i] += buc[i - 1];
        for (int i = 1; i <= n; i++) sa[buc[fir[tmp[i]]]--] = tmp[i];

        // Re-rank: neighbours in sa share a rank exactly when both keys match.
        // The scan starts at i = 2 so index 0 of the arrays is never consulted.
        bool all_distinct = true;
        rk[sa[1]] = 1;
        for (int i = 2; i <= n; i++) {
            const int prev = sa[i - 1];
            const int cur = sa[i];
            if (fir[cur] == fir[prev] && sec[cur] == sec[prev]) {
                rk[cur] = rk[prev];
                all_distinct = false;
            } else {
                rk[cur] = rk[prev] + 1;
            }
        }
        if (all_distinct) break;  // all ranks differ: the order is final
    }

    // Height array, visiting suffixes in text order so the previous match
    // length minus one can be reused as the starting point.
    for (int i = 1, k = 0; i <= n; i++) {
        if (rk[i] == 1) {
            // No suffix precedes the rank-1 suffix; handle it explicitly
            // instead of comparing against position 0.
            ht[1] = 0;
            k = 0;
            continue;
        }
        if (k) k--;
        const int j = sa[rk[i] - 1];
        while (i + k <= n && j + k <= n && a[i + k] == a[j + k]) k++;
        ht[rk[i]] = k;
    }
}

int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
build_sa();
for (int i = 2; i <= n; i++) {
res += 1LL * i * (i - 1) / 2 + 1LL * i * (i - 1);
}
for (int i = 2; i <= n; i++) {
while (!st.empty() && ht[i] < ht[st.top()]) {
f[st.top()] = i - 1;
st.pop();
}
st.push(i);
}
while (!st.empty()) {
f[st.top()] = n;
st.pop();
}
for (int i = n; i >= 2; i--) {
while (!st.empty() && ht[i] <= ht[st.top()]) {
g[st.top()] = i + 1;
st.pop();
}
st.push(i);
}
while (!st.empty()) {
g[st.top()] = 2;
st.pop();
}
for (int i = 2; i <= n; i++) {
res -= 2LL * ht[i] * (f[i] - i + 1) * (i - g[i] + 1);
}
printf("%lld\n", res);
return 0;
}