[CCPC2020绵阳C] Code a Trie - Trie,贪心,LCA
Description
有一个 Trie,其可能被插入了一些串,每个节点(含根节点)上都有一个值,这些值互不相同。在 Trie 上查询一个串时如果找到了就返回这个串结束节点的值,否则返回最后到达的节点的值。给定若干个查询串及查询结果,问是否存在这样的 Trie,有解时这个 Trie 最少有几个节点。
Solution
核心想法:根据所有 query 建立字典树,那么对于同一个串,答案一定是同一个;对于每一个值计算对应串的LCA, 然后把LCA标记一下,标记这些串在LCA下面的那个点为一定不存在,然后dfs贪心计算每个子树最少的节点
首先在插入的时候,我们只插到 LCA 为止,对 LCA 打标记,并且把下面的不能用的出端封死
如果这个位置已经有边或者待会这个位置准备建边或者我们重复标记了一个点,这些情况都是无解的
记录每个点被几组值访问了,如果它是 LCA 或者它被走了两次以上,那么这个点必须选
如果不是 LCA 并且只被走了一次,这种点我们称为不定点,一个 LCA 下面挂的所有不定点都必须要选,一个普通点下面挂的所有必定点中最多有一个不选
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 5;
int cid = 0;
int ind = 1, ch[N][26], cnt[N], tag[N]; // for trie;
int f[N]; // for dfs;
void dfs(int p)
{
if (cnt[p] > 1)
f[p] = 1;
int flag = tag[p];
for (int i = 0; i < 26; i++)
{
int q = ch[p][i];
if (q == 0)
continue;
if (cnt[q] == 1)
{
f[q] = flag;
flag = 1;
}
else
dfs(q);
}
}
void solve()
{
int n;
cin >> n;
while (ind > 1)
{
for (int i = 0; i < 26; i++)
ch[ind][i] = 0;
tag[ind] = 0;
cnt[ind] = 0;
f[ind] = 0;
--ind;
}
for (int i = 0; i < 26; i++)
ch[ind][i] = 0;
tag[ind] = 0;
cnt[ind] = 0;
f[ind] = 0;
vector<pair<string, int>> src;
while (n--)
{
pair<string, int> a;
cin >> a.first >> a.second;
src.push_back(a);
}
map<int, int> mp;
for (auto [x, y] : src)
mp[y]++;
int idx = 0;
for (auto &[x, y] : mp)
y = ++idx;
for (auto &[x, y] : src)
y = mp[y];
vector<vector<string>> vec(idx + 2);
for (auto [x, y] : src)
vec[y].push_back(x);
for (auto strs : vec)
{
if (strs.size() < 1)
continue;
int pos = 0;
int p = 1;
while (true)
{
int flag = 1;
for (int i = 0; i < strs.size(); i++)
{
if (pos >= strs[i].size() || strs[i][pos] != strs[0][pos])
{
flag = 0;
break;
}
}
if (flag)
{
cnt[p]++;
if (ch[p][strs[0][pos] - 'a'] == -1)
{
++cid;
cout << "Case #" << cid << ": ";
cout << -1 << endl;
return;
}
if (ch[p][strs[0][pos] - 'a'] == 0)
{
ch[p][strs[0][pos] - 'a'] = ++ind;
}
p = ch[p][strs[0][pos] - 'a'];
++pos;
}
else
break;
}
if (tag[p])
{
++cid;
cout << "Case #" << cid << ": ";
cout << -1 << endl;
return;
}
tag[p]++;
cnt[p]++;
for (int i = 0; i < strs.size(); i++)
{
if (pos < strs[i].size())
{
if (ch[p][strs[i][pos] - 'a'] > 0)
{
++cid;
cout << "Case #" << cid << ": ";
cout << -1 << endl;
return;
}
ch[p][strs[i][pos] - 'a'] = -1;
}
}
}
cnt[1]++;
dfs(1);
int ans = 0;
for (int i = 1; i <= ind; i++)
ans += f[i];
++cid;
cout << "Case #" << cid << ": ";
cout << ans << endl;
}
signed main()
{
ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--)
solve();
}