你也可以通过我的独立博客 —— www.huliujia.com 获取本篇文章
数组与链表
数组和链表是非常基础的两种数据结构,链表每次查找都需要从头结点开始线性遍历,时间复杂度是O(n)。而数组通过维护元素顺序可以使用二分查找,查找的时间复杂度是O(lg(n))。查找效率方面数组完胜链表。
但是由于数组插入或删除元素时必须要移动所有受影响的节点,所以时间复杂度是O(n).并且数组的长度是固定的,当空间不够时需要重新分配内存。而链表插入和删除元素只需要改动少量指针即可,时间复杂度是O(1)。但是因为插入和删除一半都需要先查找元素,所以实际上链表的插入和删除时间复杂度还是O(n)。
今天介绍一种基于链表的数据结构 —— 跳跃表(Skiplist),在保持链表便于插入、删除的特性同时,可以把查找、插入、删除的时间复杂度降到O(lg(n))。
跳跃表的原理
数组能够实现查找时间复杂度O(lg(n)),主要是因为二分查找每次都可以排除一半的元素,那么链表有没有办法也每次排除一半的元素呢?
显然,原生的链表每次只能排除一个元素(当前元素),想要排除一半的元素需要满足两个条件,首先是链表必须是有序的,其次是能够访问到最中间的元素,这个其实就是数组二分查找的原理。
让链表有序是一个比较容易实现的需求,那么如何访问到最中间的元素呢?我们可以使用一个外部节点来保存中间节点。
比如上图是一个有5个元素的有序链表,红色的元素3指向了中间的元素3,如果我们想要查找4,通过和中间元素3进行比较,很容易判断要查找的元素在3的右侧,那么就排除了3左侧的所有节点了。
上面的链表只有5个元素,那如果链表有更多的元素,比如8个元素呢?显然一个外部节点只能完成一次元素排除,如果想每次访问都能排除一半的元素,即需要更多的外部节点了。
比如上图链表中有8个元素,总共使用了3个外部节点,外部节点分成了两层,通过和最上层外部节点元素5比较,我们可以排除掉一半的元素,剩下的元素和元素3或者元素7比较,可以再排除掉一半的元素。
类似的,如果元素更多的话,可以使用更多的外部节点,用更多的层来管理外部节点。
可以看到,通过添加外部节点,我们可以实现每次排除一半的元素,那么和二分查找类似,最终可以实现查找时间复杂度为O(lg(n))。
跳跃表的实现
上面使用外部节点作为例子来演示了跳跃表查找加速的原理,但是如果直接按照上面的方式来实现的话,外部节点的管理和查找操作都很复杂,难以实现。
所以实际实现时,会把每一层的外部节点也使用有序链表的方式来管理,并且给每个链表添加一个假的头结点和尾节点,方便查找和判断是否到达了链表尾部。
最后我们的跳跃表应该是这样的
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eRW18yAr-1618188079241)(/images/跳跃表-图3.png)]
可以看到,每层的节点都变成了一个链表,越往上层,链表的元素个数越少。每层的链表都包含头结点Hx和尾节点Tx。
那么多少层链表的效率最高呢?这个取决于最下层链表的元素数量,为了实现时间复杂度O(lg(n))。每层链表的元素数量(不含头尾节点)应当是其下层元素数量的二分之一,并且最上层链表的元素数量为1。那么很容易计算出最佳的层数是 lg(N),N是链表最底层的元素数量,也就是链表的实际元素数量。比如有8个节点,那么层数应当是lg(8) = 3;
查找元素
查找时,以最高层的头结点H2节点为起点,比较搜索的元素X和右侧元素Y的大小,根据判断结果决定是返回Y (X == Y),还是在Y的左侧继续搜索(X < Y),或者是在Y的右侧搜索(X > Y)。重复这个过程,直到找到X或者确定X不存在。
假设X=4,那么第一步比较X和H2的右侧元素5,X < 5,所以在5的左侧查找,切换到下一层,起点由H2变为H1,比较X和H1的右侧元素3,X > 3,所以应当在3的右侧查找,切换到下一层,起点由H1变为元素3,比较X和元素3的右侧元素5。X < 5,所以在元素5的左侧查找,起点切换到第0层的元素3(刚才是第1层的元素3),比较X和3的右侧元素4,X==4,找到了X。
如果 X = 4.5 呢?显然前面的过程和X=4时是一样的,但是当起点切换到第0层的元素3时,比较X和3右侧的元素4,发现 X > 4,由于此时已经在第0层了,无法继续往下搜索了,所以可以判断元素 X = 4.5 是不存在的。
插入元素
插入元素时,先确定插入位置,然后修改前后的指针即可。但是这样存在一个问题,当插入元素越来越多时,上层的链表显然无法再实现均匀地二分了,那么查找的效率就会降低。
是否可以在每次插入的时候重新组织上层链表,让上层链表元素可以均匀地对下层链表进行二分呢?答案显然也是不行的,因为这样每次插入的时间复杂度就变成O(n)了。跳跃表也常被称为是基于概率的数据结构,所谓基于概率就是在这里涉及到的。
简单来说就是通过概率的方式,让每层链表元素的数量是其下层元素数量的1/2,只要让每个第N层的节点有1/2的概率进入第N+1层即可。当元素数量足够多时,第N+1层的链表元素可以近似均匀地分割第N层的元素了。
我们在插入一个新的元素时,首先把元素插入到最低层,即第0层,然后让其有二分之一的概率(比如通过随机数)进入第1层,如果进入了第1层,再让其有1/2的概率进入第2层,依次类推。如果当前层已经是允许的最高层,就在当前层停下。
实际实现时,是先基于概率计算这个元素最高到哪一层,比如元素最高到M层,就把元素插入到第0层到第M层之间(含第M层)的每一层。
删除元素
删除元素比较简单,查找到元素后,把该元素在每一层的节点都删除即可。因为每个元素在第N层的概率都是相同的,所以删除元素时不会影响到多层链表的分布概率。
代码实现
实际代码实现时,为了提高性能并减少存储开销,并不会真的创建多层链表,每个元素实际只会有一个节点,节点除了保存元素的值,还有一个指针数组next,用于保存元素在不同层的链表关系。next[i]指向元素在第i层的右侧元素。如果next[i]为空,表示第i层没有当前元素。如果next[i]指向了元素A,表示当前元素在第i层的右侧元素是A。具体代码实现如下:
#include <stdio.h>
#include <cstring>
#include <random>
#include <vector>
const int kMaxLevelNum = 12;
struct Node
{
int value_;
Node *next_[kMaxLevelNum];
};
class Skiplist
{
public:
Skiplist()
{
//每层都创建头尾节点
head_ = new Node();
head_->value_ = -1;
auto *tail = new Node();
tail->value_ = 1;
memset(tail->next_, 0, kMaxLevelNum * sizeof(void *));
for (int i = 0; i < kMaxLevelNum; i++)
{
head_->next_[i] = tail;
}
}
bool search(int target)
{
Node *prev[kMaxLevelNum];
//FindEqual会查找目标节点,并保存每层的前置节点
return FindEqual(target, prev) != nullptr;
}
void add(int num)
{
Node *prev[kMaxLevelNum];
//FindGreaterOrEqual会找到大于等于目标值的最小节点,并保存每层的前置节点
Node *ge_node = FindGreaterOrEqual(num, prev);
Node *node = new Node();
node->value_ = num;
int height = RandomHeight();
for (int i = 0; i < height; i++)
{
node->next_[i] = prev[i]->next_[i];
prev[i]->next_[i] = node;
}
}
bool erase(int num)
{
Node *prev[kMaxLevelNum];
//先查找再删除
auto *ge_node = FindEqual(num, prev);
if (ge_node == nullptr)
{
return false;
}
for (int i = 0; i < kMaxLevelNum; i++)
{
if (ge_node->next_[i] == nullptr)
{
break;
}
prev[i]->next_[i] = ge_node->next_[i];
ge_node->next_[i] = prev[i];
}
delete ge_node;
return true;
}
//分层当前的跳跃表,这个主要是为了让输出可视化,不涉及到跳跃表的插入、删除、查找。
void PrintAll()
{
auto* cur_node = head_->next_[0];
int counter = 1;
std::vector<Node*> node_list;
std::string debug_str[kMaxLevelNum];
for(int i=0; i<kMaxLevelNum; i++)
{
debug_str[i] += "|";
}
printf("%p, ",head_);
while (cur_node->next_[0] != nullptr)
{
char buf[100];
sprintf(buf, "%-2d", cur_node->value_);
debug_str[0] += " -> ";
debug_str[0] += buf;
node_list.emplace_back(cur_node);
printf("%p, ",cur_node);
cur_node = cur_node->next_[0];
}
printf("%p, ",cur_node);
printf("\n");
debug_str[0] += "-> |";
cur_node = head_;
for(int i=1; i<kMaxLevelNum; i++)
{
debug_str[i] = "|";
auto* cur_left_node = head_;
for(auto* node: node_list)
{
printf("node: %p, left: %p, left->next[i]: %p\n", node, cur_left_node, cur_left_node->next_[i]);
printf("node: %d, left: %d\n", node->value_, cur_left_node->next_[i]->value_);
if(node == cur_left_node->next_[i])
{
printf("true\n");
char buf[100];
sprintf(buf, "%-2d", node->value_);
debug_str[i] += " -> ";
debug_str[i] += buf;
cur_left_node = node;
}else
{
printf("false\n");
debug_str[i] += "------";
}
}
debug_str[i] += "-> |";
}
for(int i=kMaxLevelNum-1; i>=0; i--)
{
printf("level: %d: %s\n", i, debug_str[i].c_str());
}
printf("\n");
}
private:
Node *head_;
int cur_max_level_ = 0;
int RandomHeight()
{
static std::mt19937 mt_rand(0);
int height = 1;
while (height < kMaxLevelNum and mt_rand() & 0x1)
{
height++;
}
return height;
}
Node *FindEqual(int target, Node **prev)
{
Node *ge_node = FindGreaterOrEqual(target, prev);
if (ge_node->next_[0] != nullptr and ge_node->value_ == target)
{
return ge_node;
}else
{
return nullptr;
}
}
//返回值应当是比key大的最小的数,prev是一个长度为kMaxHeight的数组,保存每层比key小的数的指针。最后是要利用这个指针把当前的key嵌入进去的。
Node *FindGreaterOrEqual(int target, Node **prev)
{
int level = kMaxLevelNum - 1;
Node *cur_node = head_;
while (true)
{
auto *next = cur_node->next_[level];
if (next == nullptr) // right edge
{
return cur_node;
}else if (next->next_[level] == nullptr or next->value_ >= target)
{
//right node is right edge or is larger
if (level == 0)
{
prev[level] = cur_node;
return next;
}else
{
prev[level] = cur_node;
level--;
}
}else //(next->value_ < target)
{
cur_node = next;
}
}
}
};
int main()
{
auto *obj = new Skiplist();
const char *name[3];
name[0] = "add";
name[1] = "erase";
name[2] = "search";
obj->PrintAll();
for (int i = 0; i < 30; i++)
{
int which = rand() % 3;
int num = rand() % 100;
printf("i:%d, %s(%d)\n", i, name[which], num);
switch (which)
{
case 0:
obj->add(num);
break;
case 1:
obj->erase(num);
break;
case 2:
obj->search(num);
break;
default:
break;
}
}
obj->PrintAll();
return 0;
}