字典树

字典树

算法思想

字典树(\(Trie\))是一个比较简单的数据结构,也叫前缀树或 \(Trie\) 树,用来存储和查询字符串。

例如,waterwishwintietired 这几个单词可以用以下方式存储 :

字典树

此时每一个叶子结点递归往上到根节点都是对应一个字符串。

其中每个字符占据一个节点,拥有相同前缀的字符串可以共用部分节点。

起始点是特殊点(我们设为 \(1\) 号点),不存储字符。

建树的代码如下:

const int MAXN = 500005;
int nxt[MAXN][26], cnt; // nxt[i][c] 表示 i 号点所连、存储字符为 c + 'a' 的点的编号
void init()              // 初始化
{
    memset(nxt, 0, sizeof(nxt));
    cnt = 1;
}
void insert(const string &s) // 插入字符串
{
    int cur = 1;
    for (auto c : s)
    {
        // 尽可能重用之前的路径,如果做不到则新建节点
        if (!nxt[cur][c - 'a'])
            nxt[cur][c - 'a'] = ++cnt;
        cur = nxt[cur][c - 'a']; // 继续向下
    }
}

查询代码如下 :

bool Find(const string &s) // 查找某个前缀是否出现过
{
    int cur = 1;
    for (auto c : s)
    {
        // 沿着前缀所决定的路径往下走,如果中途发现某个节点不存在,说明前缀不存在
        if (!nxt[cur][c - 'a'])
            return false;
        cur = nxt[cur][c - 'a'];
    }
    return true;
}

如果是查询某个字符串是否存在,可以另开一个 \(vis\) 数组,在插入完成时,把 \(vis[叶子节点]\) 设置为 \(true\),然后先按查询前缀的方法查询,在结尾处再判断一下 \(vis\) 的值。

这是一种常见的套路,即用叶子节点代表整个字符串,保存某些信息。

字典树是一种比较典型的空间换时间的数据结构,我们牺牲了字符串个数 \(\times\) 字符串平均字符数 \(\times\) 字符集大小的空间,但可以用 \(O(n)\) 的时间查询,其中 \(n\) 为查询的前缀或字符串的长度。

代码实现

例题

给出 \(n\) 个字符串和 \(m\) 个查询,每次查询给出一个字符串 \(s\)。如果 \(s\) 没有被给出,输出 WRONG,如果 \(s\) 被给出并且是第一次被查询,输出 OK,如果 \(s\) 被给出并且不是第一次被查询,输出 REPEAT

解题思路

维护每一个字符串出现的次数即可,记录在叶子结点上。

AC CODE

#include <bits/stdc++.h>
using namespace std;

const int _ = 3e5 + 5;

int n, m, cnt;
int nxt[_][26];
bool vis[_];
char s[60];

void insert(string s)
{
    int t = 0, len = s.size();
    for (int i = 0; i < len; i++)
    {
        int c = s[i] - 'a';
        if (!nxt[t][c])
            nxt[t][c] = ++cnt;
        t = nxt[t][c];
    }
}

int search(string s)
{
    int t = 0, len = s.size();
    for (int i = 0; i < len; i++)
    {
        int c = s[i] - 'a';
        if (!nxt[t][c])
            return 0;
        t = nxt[t][c];
    }
    if (!vis[t])
    {
        vis[t] = true;
        return 1;
    }
    return 2;
}

signed main()
{
    int res;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%s", s);
        insert(s);
    }
    scanf("%d", &m);
    for (int i = 1; i <= m; i++)
    {
        scanf("%s", s);
        res = search(s);
        if (!res)
            puts("WRONG");
        else if (res == 1)
            puts("OK");
        else
            puts("REPEAT");
    }
    return 0;
}

扩展

例题

在给定的 \(n\) 个整数 \(a_1,a_2,\sim,a_n\) 中选出两个进行异或运算,得到的结果最大是多少?

解题思路

先引入数据结构,\(01 \ Trie\) 是 \(Trie\) 树衍生出的一种数据结构,它可以用来维护与 \(\color{red}{异或}\) 相关的题目。

\(01 \ Trie\) 十分简单,我们把若干个数转换成二进制表示,也就是若干个 \(01\) 串,然后把这些 \(01\) 串插入 \(Trie\) 树,得到的 \(Trie\) 树称之为 \(01 \ Trie\)。

异或 :将两个数都变成二进制,只有在两个比较的位不同时其结果是 \(1\),否则结果为 \(0\)。

(x >> i) & 1 判断十进制数 \(x\) 变成二进制后的第 \(i\) 位是多少。

我们可以发现,想令异或和最大,那么我们应该尽量保证高位的数不相同。

因为假设我们令从左往右数第 \(x\) 位不同,但我们不取第 \(x\) 位,那么即使后面所有位上的数字都不相同,损失最多为 \(\sum\limits_{i=0}^{x-1}2^i<2^x\)。

如果我们把第 \(x\) 位取相反的数,后面可能的贡献总和也一定比不取第 \(x\) 位的损失要小。

于是我们可以想出一个贪心策略。对于给定的数 \(x\),我们可以二进制从高位到低位遍历。

假设当前遍历到了从左往右的第 \(i\) 位,若深度为 \(i+1\) 的结点中存在权值与 \(x\) 的第 \(i\) 位相反的结点,我们递归进入该结点;

反之进入另外一个结点递归。

AC CODE

#include <bits/stdc++.h>
using namespace std;

int tr[3500005][2], n, cnt = 1, ans;

inline int read()
{
    static char ch;
    static int n;
    n = 0;
    while (ch < '0' || ch > '9')
        ch = getchar();
    while (ch >= '0' && ch <= '9')
        n = n * 10 + ch - '0', ch = getchar();
    return n;
}

inline void insert(int v)
{
    int u = 1;
    for (int i = 31; i >= 0; --i)
    {
        int k = (v >> i) & 1;
        if (!tr[u][k])
            tr[u][k] = ++cnt;
        u = tr[u][k];
    }
}

inline int query(int x)
{
    int u = 1, ans = 0;
    for (int i = 31; i >= 0; --i)
    {
        int c = (x >> i) & 1;
        if (tr[u][c ^ 1])
        {
            u = tr[u][c ^ 1];
            ans += (int)(1 << i);
        }
        else
            u = tr[u][c];
    }
    return ans;
}

signed main()
{
    n = read();
    for (int i = 1; i <= n; ++i)
    {
        int a = read();
        insert(a);
        ans = max(ans, query(a));
    }
    printf("%d", ans);
    return 0;
}
上一篇:自动机(估计要写几天)


下一篇:前缀树(Trie)两种方式实现详解--C++数据结构的实现