Splay
1. Splay原理
原理
-
Splay是一颗二叉树,期望高度是
log(n)
的,但是不是完全平衡的。 -
和其他的平衡树一样,也存在基本的左旋和右旋操作,左旋和右旋示意图如下:
- 该函数实现如下:
/*
对节点x进行向右旋转操作(//是需要变更的关系)
z z
/ \ // \ (1)
y D 向右旋转 (x) x D
/ \ - - - - - - - -> / \\ (2)
x C A y
/ \ // \ (3)
A B B C
*/
// 根据x与其父节点y,以及y的父节点z之间的关系对x进行旋转
// 存在两种情况(左旋、右旋),可以写成一个函数
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x; // k=0表示x是y的左儿子;k=1表示x是y的右儿子
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z; // 变更(1)
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y; // 变更(3)
tr[x].s[k ^ 1] = y, tr[y].p = x; // 变更(2)
pushup(y), pushup(x);
}
-
Splay
的核心函数是splay(x, k)
函数,其含义是将x
旋转到节点k
的下面,一共存在四种情况,因为是对称的,这里给出两种情况的旋转方式,如下图:
- 该函数实现如下:
// 存在四种情况
/*
z z z z
/ / \ \
y y y y
/ \ / \
x x x x
*/
void splay(int x, int k)
{
while (tr[x].p != k) // x还未转为k的子节点
{
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x); // 说明是折线情况
else rotate(y);
rotate(x);
}
if (!k) root = x; // k=0说明将当前节点转到根节点
}
-
如果要将某段连续的数据插入
Splay
中,我们首先找到需要插入的位置的前驱和后继节点y、z
,然后使用splay(y, 0)
将y
转到根节点,使用splay(z, y)
将z
转成y
的儿子,此时z
的左儿子一定为空,将需要插入的数据构造成一棵树插到z
的左儿子上即可。 -
如果要删除一段数据,则找到这段数据在树中对应节点编号的端点的前一个节点
l
、以及后一个节点r
,然后使用splay(l, 0)
将l
转到根节点,使用splay(r, l)
将r
转成l
的儿子,最后将r
的左孩子删除即可。 -
注意:
splay
不能保证中序遍历得到的序列是有序的,其维护的是:splay
中序遍历始终是当前序列的顺序。
2. AcWing上的Splay题目
AcWing 2437. Splay
问题描述
-
问题链接:AcWing 2437. Splay
分析
-
首先考虑如何构建
splay
,因为这里输入的序列是1~n
,是有序的,因此可以每次根据值向树中插入一个数据,保证该树中序遍历是有序的即可。 -
另外一种构建树的方式:类似于线段树构建的方式,每次取序列中间的数据作为根节点,该数据左边和右边分别为左右子树,构建二叉树即可。
-
因为当我们操作的是第一个点或者最后一个点,可能越界,因此需要两个哨兵,一个是负无穷,另一个是正无穷。
-
因为每次需要找到旋转的一段连续区间在
splay
中的位置,我们需要实现get_k(int k)
函数返回中序遍历中第k
个节点的编号。 -
假设我们要旋转原序列
[l, r]
之间的数据,则需要找到原序列第l-1
和第r+1
个数据在树中的位置,因为存在哨兵,原序列第l-1
个位置在树中的节点编号可以使用get(l)
得到,同理第r+1
个数据在树中的位置可以使用get(r+2)
得到。 -
考虑
splay
中每个节点需要记录的信息,因为要得到在中序遍历中第k
个数据,因此需要记录以每个节点为根的子树中节点的个数size
;另外还要有个懒标记flag
,表示以当前节点为根的子树是否需要翻转。 -
每次在左旋或者右旋后需要调用
pushup
更新size
;在每次递归splay
前需要调用pushdown
将懒标记下传。
代码
- C++
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, INF = 1e9;
int n, m;
struct Node
{
int s[2], p, v; // 孩子, 父节点, 节点对应的值
int size; // 以当前节点为根的子树中节点个数
int flag; // 懒标记, 表示当前节点的子树是否需要旋转
void init(int _v, int _p)
{
v = _v, p = _p;
size = 1;
}
}tr[N];
int root, idx;
int w[N]; // 序列,这里是1~n
void pushup(int x)
{
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}
void pushdown(int x)
{
if (tr[x].flag)
{
swap(tr[x].s[0], tr[x].s[1]);
tr[tr[x].s[0]].flag ^= 1;
tr[tr[x].s[1]].flag ^= 1;
tr[x].flag = 0;
}
}
/*
对节点x进行向右旋转操作(//是需要变更的关系)
z z
/ \ // \ (1)
y D 向右旋转 (x) x D
/ \ - - - - - - - -> / \\ (2)
x C A y
/ \ // \ (3)
A B B C
*/
// 根据x与其父节点y,以及y的父节点z之间的关系对x进行旋转
// 存在两种情况(左旋、右旋),可以写成一个函数
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x; // k=0表示x是y的左儿子;k=1表示x是y的右儿子
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z; // 变更(1)
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y; // 变更(3)
tr[x].s[k ^ 1] = y, tr[y].p = x; // 变更(2)
pushup(y), pushup(x);
}
// 存在四种情况
/*
z z z z
/ / \ \
y y y y
/ \ / \
x x x x
*/
void splay(int x, int k)
{
while (tr[x].p != k) // x还未转为k的子节点
{
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x); // 说明是折线情况
else rotate(y);
rotate(x);
}
if (!k) root = x; // k=0说明将当前节点转到根节点
}
// 返回splay的中序遍历中第k个节点(从1开始)对应的节点编号
int get_k(int k)
{
int u = root;
int cnt = 0;
while (true)
{
pushdown(u);
if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
else if (tr[tr[u].s[0]].size + 1 == k) return u;
else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}
void output(int u)
{
pushdown(u);
if (tr[u].s[0]) output(tr[u].s[0]);
if (tr[u].v > -INF && tr[u].v < INF) printf("%d ", tr[u].v);
if (tr[u].s[1]) output(tr[u].s[1]);
}
int build(int l, int r, int p)
{
int mid = l + r >> 1;
int u = ++idx;
tr[u].init(w[mid], p);
if (l < mid) tr[u].s[0] = build(l, mid - 1, u);
if (mid < r) tr[u].s[1] = build(mid + 1, r, u);
pushup(u);
return u;
}
int main()
{
scanf("%d%d", &n, &m);
w[0] = -INF, w[n + 1] = INF; // 哨兵
for (int i = 1; i <= n; i ++ ) w[i] = i;
root = build(0, n + 1, 0); // 构建splay
while (m -- )
{
int l, r;
scanf("%d%d", &l, &r);
l = get_k(l), r = get_k(r + 2);
splay(l, 0), splay(r, l);
tr[tr[r].s[0]].flag ^= 1;
}
output(root);
return 0;
}
AcWing 950. 郁闷的出纳员
问题描述
-
问题链接:AcWing 950. 郁闷的出纳员
分析
- 本题我们可以对序列进行的操作是:
(1)向序列添加一个数;
(2)当前序列中的所有的数据加上一个数;
(3)当前序列中的所有的数据减去一个数;
(4)删除序列中小于某个值m
的所有数据。
(5)返回第k
大的数据。
-
可以使用
splay
维护整个序列,本题维护的序列中序遍历是有序的。考虑如何实现这几个操作: -
对于(1),直接将给定数据插入树中即可。
-
对于(2)(3),可以使用一个全局变量
delta
记录当前工资的增加或者减少的值;某个原始值为x
的话,则实际对应的值为x+delta
,如果x+delta<m
,则需要将x
从序列中删除,即删除所有小于m-delta
的数据。注意加入delta
后操作(1)插入的值要变为x-delta
,这样实际插入的值是x
。 -
对于(4),因为维护的是升序序列,我们删除的一定是最靠前的一段,其中
R
是大于等于m-delta
最小的数对应的节点编号,如下图:
-
我们找到需要删除的一段的前驱
L
和后继R
,然后将R
转到根节点,L
转为R
的左孩子,然后删除L
的右孩子即可。 -
对于(5),我们需要记录
splay
中每棵子树的节点数size
,我们需要实现get_k
,返回第k
小的数。那么如何求解第k
大的数呢?因为有哨兵,我们需要返回第k+1
大的数,如果一共当前序列中有s
个数据,则返回第s-k
小的数即可,记得最后结果还要加上delta
。
代码
- C++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100010, INF = 1e9;
int n, m; // 操作数、下界
int delta; // 工资变化量
struct Node {
int s[2], p, v; // 本题splay根据v排序
int size;
void init(int _v, int _p) {
v = _v, p = _p;
size = 1;
}
} tr[N];
int root, idx;
void pushup(int u) {
tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + 1;
}
void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
rotate(x);
}
if (!k) root = x;
}
int insert(int v) {
int u = root, p = 0;
while (u) p = u, u = tr[u].s[v > tr[u].v];
u = ++idx;
if (p) tr[p].s[v > tr[p].v] = u;
tr[u].init(v, p);
// 为了保证splay高度是log级别的,同时调用pushup更新节点
// 使用下面注释的四行逻辑也正确,但是会超时
splay(u, 0);
// if (!root) root = 1;
// while (tr[u].p) {
// pushup(tr[u].p);
// u = tr[u].p;
// }
return u;
}
// 返回值大于等于v的最小的数对应的节点编号
int get(int v) {
int u = root, res = 0;
while (u) {
if (tr[u].v >= v) res = u, u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
// 返回第k小的数据
int get_k(int k) {
int u = root;
while (u) {
if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
else if (tr[tr[u].s[0]].size + 1 == k) return tr[u].v;
else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}
int main() {
scanf("%d%d", &n, &m);
int L = insert(-INF), R = insert(INF);
int tot = 0;
while (n--) {
char op[2];
int k;
scanf("%s%d", op, &k);
if (*op == 'I') {
if (k >= m) k -= delta, insert(k), tot++;
} else if (*op == 'A') {
delta += k;
} else if (*op == 'S') {
delta -= k;
R = get(m - delta); // 获取大于等于m-delta的节点
splay(R, 0), splay(L, R);
tr[L].s[1] = 0; // 删除L的右子树
pushup(L), pushup(R);
} else {
if (tr[root].size - 2 < k) puts("-1");
else printf("%d\n", get_k(tr[root].size - k) + delta);
}
}
printf("%d\n", tot - (tr[root].size - 2)); // 输出离职员工数目
return 0;
}
AcWing 1063. 永无乡
问题描述
-
问题链接:AcWing 1063. 永无乡
分析
-
本题需要支持将两个集合合并,还要支持求某个集合中都
k
小的数据。 -
思路:使用并查集记录岛屿的连通性,使用
splay
记录每个集合。 -
每次合并两个集合时,直接暴力将某个
splay
合并到另一个splay
即可。这里采用启发式合并,每次将元素较少的集合合并到元素较多的集合,这样合并操作可以保证时间复杂度是 O ( n × l o g ( n ) ) O(n \times log(n)) O(n×log(n)),加上splay
插入的时间复杂度,因此时间复杂度为 O ( n × l o g 2 ( n ) ) O(n \times log ^2 (n)) O(n×log2(n))。
代码
- C++
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1800010;
int n, m;
struct Node {
int s[2], p, v;
int id; // 岛屿的编号
int size; // 以当前节点为根的子树节点数目
void init(int _v, int _id, int _p) {
v = _v, p = _p, id = _id;
size = 1;
}
} tr[N];
int root[N], idx; // root: 每个splay的根节点编号
int p[N]; // 并查集, 集合中的祖宗节点存储的是splay的根节点编号
int find(int x) {
if (x != p[x]) p[x] = find(p[x]);
return p[x];
}
void pushup(int u) {
tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + 1;
}
void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int x, int k, int b) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(x);
rotate(x);
}
if (!k) root[b] = x;
}
// 获取根节点编号为root[b]的splay中序遍历中第k个节点的id
int get_k(int k, int b) {
int u = root[b];
while (u) {
if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
else if (tr[tr[u].s[0]].size + 1 == k) return tr[u].id;
else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}
// 将节点重要度为v, 岛屿的编号为id的节点插入到以root[b]为根节点的splay中
void insert(int v, int id, int b) {
int u = root[b], p = 0;
while (u) p = u, u = tr[u].s[v > tr[u].v];
u = ++idx;
if (p) tr[p].s[v > tr[p].v] = u;
tr[u].init(v, id, p);
splay(u, 0, b);
}
// 将根节点编号为u的splay合并到根节点为root[b]的splay中
void dfs(int u, int b) {
if (tr[u].s[0]) dfs(tr[u].s[0], b);
if (tr[u].s[1]) dfs(tr[u].s[1], b);
insert(tr[u].v, tr[u].id, b);
}
int main() {
scanf("%d%d", &n, &m);
// 初始化并查集
for (int i = 1; i <= n; i++) p[i] = i;
// 初始化n个splay
for (int i = 1; i <= n; i++) {
root[i] = i;
int v;
scanf("%d", &v);
tr[i].init(v, i, 0); // 值为v, 岛屿编号为i, 没有父节点
}
idx = n;
// 根据初始边合并splay
while (m--) {
int a, b;
scanf("%d%d", &a, &b);
a = find(a), b = find(b);
if (a != b) {
if (tr[root[a]].size > tr[root[b]].size) swap(a, b);
dfs(root[a], b); // 将root[a]合并到root[b]中
p[a] = b;
}
}
// 操作
scanf("%d", &m);
while (m--) {
char op[2];
int a, b;
scanf("%s%d%d", op, &a, &b);
if (*op == 'B') {
a = find(a), b = find(b);
if (a != b) {
if (tr[root[a]].size > tr[root[b]].size) swap(a, b);
dfs(root[a], b);
p[a] = b;
}
} else {
a = find(a);
if (tr[root[a]].size < b) puts("-1");
else printf("%d\n", get_k(b, a));
}
}
return 0;
}
AcWing 955. 维护数列
问题描述
-
问题链接:AcWing 955. 维护数列
分析
-
为了求解最大连续子序和,需要维护如下变量:
sum、ms、ls、rs
,分别表示:序列和、最大子序和、最大前缀和、最大后缀和。 -
对于修改和翻转操作,需要使用懒标记,分别使用
rev、same
记录,当区间中的值改变后,我们一定要将、更新区间中维护的信息。 -
求最大子序和:直接返回根节点的
ms
值即可。其余的五个操作过程类似,都是找到序列的前驱l
和后继r
,然后将l
旋转到根节点,r
旋转成l
的子节点,然后对r
的左孩子进行操作即可。 -
另外注意:本题中间用到的点数可能非常多,因此需要使用垃圾回收机制,将每次删除的节点编号回收。
代码
- C++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 500010, INF = 1e9;
int n, m;
struct Node {
int s[2], p, v; // 左右孩子、父节点、值
int rev; // 懒标记, 以当前节点为根的子树是否需要翻转
int same; // 懒标记, 以当前节点为根的子树是否需要变成和根节点相同的数
int size; // 以当前节点为根的子树节点数目
int sum; // 以当前节点为根的子树节点和
int ms; // 最大连续字段和(至少包含一个数)
int ls; // 最大前缀和
int rs; // 最大后缀和
void init(int _v, int _p) {
s[0] = s[1] = 0, p = _p, v = _v;
rev = same = 0;
size = 1, sum = ms = v;
ls = rs = max(v, 0);
}
} tr[N];
int root;
int nodes[N], tt; // 内存回收机制, 存储可用的编号
int w[N];
// 使用u的孩子更新节点u的信息
void pushup(int u) {
auto &root = tr[u], &left = tr[root.s[0]], &right = tr[root.s[1]];
root.size = left.size + right.size + 1;
root.sum = left.sum + right.sum + root.v;
root.ls = max(left.ls, left.sum + root.v + right.ls);
root.rs = max(right.rs, right.sum + root.v + left.rs);
root.ms = max(max(left.ms, right.ms), left.rs + root.v + right.ls);
}
// 将懒标记下传
void pushdown(int u) {
auto &root = tr[u], &left = tr[root.s[0]], &right = tr[root.s[1]];
if (root.same) {
root.same = root.rev = 0;
if (root.s[0]) left.same = 1, left.v = root.v, left.sum = left.v * left.size;
if (root.s[1]) right.same = 1, right.v = root.v, right.sum = right.v * right.size;
if (root.v > 0) {
if (root.s[0]) left.ms = left.ls = left.rs = left.sum;
if (root.s[1]) right.ms = right.ls = right.rs = right.sum;
} else {
if (root.s[0]) left.ms = left.v, left.ls = left.rs = 0;
if (root.s[1]) right.ms = right.v, right.ls = right.rs = 0;
}
} else if (root.rev) {
root.rev = 0, left.rev ^= 1, right.rev ^= 1;
swap(left.ls, left.rs), swap(right.ls, right.rs);
swap(left.s[0], left.s[1]), swap(right.s[0], right.s[1]);
}
}
void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
rotate(x);
}
if (!k) root = x;
}
int get_k(int k) {
int u = root;
while (u) {
pushdown(u);
if (tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
else if (tr[tr[u].s[0]].size + 1 == k) return u;
else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
}
return -1;
}
// 将w[l...r]构造成splay, 并返回根节点编号
// p为当前节点的父节点编号,不存在的话用0表示
int build(int l, int r, int p) {
int mid = l + r >> 1;
int u = nodes[tt--]; // 为当前节点分配编号
tr[u].init(w[mid], p);
if (l < mid) tr[u].s[0] = build(l, mid - 1, u);
if (mid < r) tr[u].s[1] = build(mid + 1, r, u);
pushup(u);
return u;
}
// 回收以u为根节点的子树的节点编号
void dfs(int u) {
if (tr[u].s[0]) dfs(tr[u].s[0]);
if (tr[u].s[1]) dfs(tr[u].s[1]);
nodes[++tt] = u;
}
// void output(int u)
// {
// pushdown(u);
// if (tr[u].s[0]) output(tr[u].s[0]);
// if (tr[u].v != -INF) printf("%d ", tr[u].v);
// if (tr[u].s[1]) output(tr[u].s[1]);
// }
int main() {
// 待分配的节点
for (int i = 1; i < N; i++) nodes[++tt] = i;
scanf("%d%d", &n, &m);
tr[0].ms = -INF; // 编号为0的点表示空节点,设为-INF可以防止更新产生错误
w[0] = w[n + 1] = -INF; // 哨兵
for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
root = build(0, n + 1, 0);
char op[20];
while (m--) {
// 输出当前序列
// cout << "当前序列为: ";
// output(root);
// cout << endl;
scanf("%s", op);
if (!strcmp(op, "INSERT")) {
int posi, tot;
scanf("%d%d", &posi, &tot);
for (int i = 0; i < tot; i++) scanf("%d", &w[i]);
int l = get_k(posi + 1), r = get_k(posi + 2);
splay(l, 0), splay(r, l);
int u = build(0, tot - 1, r);
tr[r].s[0] = u;
pushup(r), pushup(l);
} else if (!strcmp(op, "DELETE")) {
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
dfs(tr[r].s[0]); // 回收被删除的节点编号
tr[r].s[0] = 0; // 删除r的左孩子
pushup(r), pushup(l);
} else if (!strcmp(op, "MAKE-SAME")) {
int posi, tot, c;
scanf("%d%d%d", &posi, &tot, &c);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
auto &son = tr[tr[r].s[0]]; // son这棵子树值全部变为c
son.same = 1, son.v = c, son.sum = c * son.size;
if (c > 0) son.ms = son.ls = son.rs = son.sum;
else son.ms = c, son.ls = son.rs = 0;
pushup(r), pushup(l);
} else if (!strcmp(op, "REVERSE")) {
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
auto &son = tr[tr[r].s[0]];
son.rev ^= 1;
swap(son.ls, son.rs);
swap(son.s[0], son.s[1]);
pushup(r), pushup(l);
} else if (!strcmp(op, "GET-SUM")) {
int posi, tot;
scanf("%d%d", &posi, &tot);
int l = get_k(posi), r = get_k(posi + tot + 1);
splay(l, 0), splay(r, l);
printf("%d\n", tr[tr[r].s[0]].sum);
} else {
printf("%d\n", tr[root].ms);
}
}
return 0;
}