字典树
算法思想
字典树(\(Trie\))是一个比较简单的数据结构,也叫前缀树或 \(Trie\) 树,用来存储和查询字符串。
例如,water
,wish
,win
,tie
,tired
这几个单词可以用以下方式存储 :
此时每一个叶子结点递归往上到根节点都是对应一个字符串。
其中每个字符占据一个节点,拥有相同前缀的字符串可以共用部分节点。
起始点是特殊点(我们设为 \(1\) 号点),不存储字符。
建树的代码如下:
const int MAXN = 500005;
int nxt[MAXN][26], cnt; // nxt[i][c] 表示 i 号点所连、存储字符为 c + 'a' 的点的编号
void init() // 初始化
{
memset(nxt, 0, sizeof(nxt));
cnt = 1;
}
void insert(const string &s) // 插入字符串
{
int cur = 1;
for (auto c : s)
{
// 尽可能重用之前的路径,如果做不到则新建节点
if (!nxt[cur][c - 'a'])
nxt[cur][c - 'a'] = ++cnt;
cur = nxt[cur][c - 'a']; // 继续向下
}
}
查询代码如下 :
bool Find(const string &s) // 查找某个前缀是否出现过
{
int cur = 1;
for (auto c : s)
{
// 沿着前缀所决定的路径往下走,如果中途发现某个节点不存在,说明前缀不存在
if (!nxt[cur][c - 'a'])
return false;
cur = nxt[cur][c - 'a'];
}
return true;
}
如果是查询某个字符串是否存在,可以另开一个 \(vis\) 数组,在插入完成时,把 \(vis[叶子节点]\) 设置为 \(true\),然后先按查询前缀的方法查询,在结尾处再判断一下 \(vis\) 的值。
这是一种常见的套路,即用叶子节点代表整个字符串,保存某些信息。
字典树是一种比较典型的空间换时间的数据结构,我们牺牲了字符串个数 \(\times\) 字符串平均字符数 \(\times\) 字符集大小的空间,但可以用 \(O(n)\) 的时间查询,其中 \(n\) 为查询的前缀或字符串的长度。
代码实现
给出 \(n\) 个字符串和 \(m\) 个查询,每次查询给出一个字符串 \(s\)。如果 \(s\) 没有被给出,输出
WRONG
,如果 \(s\) 被给出并且是第一次被查询,输出OK
,如果 \(s\) 被给出并且不是第一次被查询,输出REPEAT
。
解题思路
维护每一个字符串出现的次数即可,记录在叶子结点上。
AC CODE
#include <bits/stdc++.h>
using namespace std;
const int _ = 3e5 + 5;
int n, m, cnt;
int nxt[_][26];
bool vis[_];
char s[60];
void insert(string s)
{
int t = 0, len = s.size();
for (int i = 0; i < len; i++)
{
int c = s[i] - 'a';
if (!nxt[t][c])
nxt[t][c] = ++cnt;
t = nxt[t][c];
}
}
int search(string s)
{
int t = 0, len = s.size();
for (int i = 0; i < len; i++)
{
int c = s[i] - 'a';
if (!nxt[t][c])
return 0;
t = nxt[t][c];
}
if (!vis[t])
{
vis[t] = true;
return 1;
}
return 2;
}
signed main()
{
int res;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%s", s);
insert(s);
}
scanf("%d", &m);
for (int i = 1; i <= m; i++)
{
scanf("%s", s);
res = search(s);
if (!res)
puts("WRONG");
else if (res == 1)
puts("OK");
else
puts("REPEAT");
}
return 0;
}
扩展
在给定的 \(n\) 个整数 \(a_1,a_2,\sim,a_n\) 中选出两个进行异或运算,得到的结果最大是多少?
解题思路
先引入数据结构,\(01 \ Trie\) 是 \(Trie\) 树衍生出的一种数据结构,它可以用来维护与 \(\color{red}{异或}\) 相关的题目。
\(01 \ Trie\) 十分简单,我们把若干个数转换成二进制表示,也就是若干个 \(01\) 串,然后把这些 \(01\) 串插入 \(Trie\) 树,得到的 \(Trie\) 树称之为 \(01 \ Trie\)。
异或 :将两个数都变成二进制,只有在两个比较的位不同时其结果是 \(1\),否则结果为 \(0\)。
(x >> i) & 1
判断十进制数 \(x\) 变成二进制后的第 \(i\) 位是多少。
我们可以发现,想令异或和最大,那么我们应该尽量保证高位的数不相同。
因为假设我们令从左往右数第 \(x\) 位不同,但我们不取第 \(x\) 位,那么即使后面所有位上的数字都不相同,损失最多为 \(\sum\limits_{i=0}^{x-1}2^i<2^x\)。
如果我们把第 \(x\) 位取相反的数,后面可能的贡献总和也一定比不取第 \(x\) 位的损失要小。
于是我们可以想出一个贪心策略。对于给定的数 \(x\),我们可以二进制从高位到低位遍历。
假设当前遍历到了从左往右的第 \(i\) 位,若深度为 \(i+1\) 的结点中存在权值与 \(x\) 的第 \(i\) 位相反的结点,我们递归进入该结点;
反之进入另外一个结点递归。
AC CODE
#include <bits/stdc++.h>
using namespace std;
int tr[3500005][2], n, cnt = 1, ans;
inline int read()
{
static char ch;
static int n;
n = 0;
while (ch < '0' || ch > '9')
ch = getchar();
while (ch >= '0' && ch <= '9')
n = n * 10 + ch - '0', ch = getchar();
return n;
}
inline void insert(int v)
{
int u = 1;
for (int i = 31; i >= 0; --i)
{
int k = (v >> i) & 1;
if (!tr[u][k])
tr[u][k] = ++cnt;
u = tr[u][k];
}
}
inline int query(int x)
{
int u = 1, ans = 0;
for (int i = 31; i >= 0; --i)
{
int c = (x >> i) & 1;
if (tr[u][c ^ 1])
{
u = tr[u][c ^ 1];
ans += (int)(1 << i);
}
else
u = tr[u][c];
}
return ans;
}
signed main()
{
n = read();
for (int i = 1; i <= n; ++i)
{
int a = read();
insert(a);
ans = max(ans, query(a));
}
printf("%d", ans);
return 0;
}