1 洛谷P3521 ROT-Tree Rotations
2 题目描述
时间限制 \(1s\) | 空间限制 \(128M\)
给定一棵有 \(n\) 个叶节点的二叉树。每个叶节点都有一个权值 \(p_i\)(注意,根不是叶节点),所有叶节点的权值构成了一个 \(1∼n\) 的排列。
对于这棵二叉树的任何一个结点,保证其要么是叶节点,要么左右两个孩子都存在。
现在你可以任选一些节点,交换这些节点的左右子树。
在最终的树上,按照先序遍历遍历整棵树并依次写下遇到的叶结点的权值构成一个长度为
\(n\) 的排列,你需要最小化这个排列的逆序对数。输出一行一个整数表示最小的逆序对数。
数据范围:
- 对于 \(30\%\) 的数据,保证 \(n \leq 5 \times 10^3\);
- 对于 \(100\%\) 的数据,保证 \(2 \leq n \leq 2 \times 10^5, 0 \leq x \leq n\) ,所有叶节点的权值是一个 \(1 ∼ n\) 的排列。
3 题解
观察题目,我们可以考虑一个贪心策略:对于每一个节点,我们只需要算出交换其左右子树和不交换其左右子树的情况下,左右子树之间分别产生的逆序对个数即可。
证明可以考虑分治的思想:左右子树内部的逆序对个数的最小值已经计算完毕了,我们现在只需要算出左右子树之间产生的逆序对个数的最小值,就可以得到当前节点为根的子树内部的逆序对个数的最小值了。
朴素的想法是,对于每一个节点,我们直接遍历其左子树所代表的所有数。对于每一个数,求出在右子树中有多少个数小于这个数。这些个数的和就是不交换时的逆序对个数。交换时的逆序对个数同理。最终把这两个逆序对个数取个 \(min\),然后直接加上左右子树的最小逆序对个数之和即可。
容易发现,这个算法的时间复杂度是 \(O(n^3)\) 的:最多存在 \(n\) 个需要遍历的节点,对于每个节点我们需要枚举其左子树和右子树里的每个数,可以近似看为 \(n\),且对于每个左子树和右子树里的数都需要遍历一遍另外一棵子树,这个也可以近似看成 \(n\)。乘在一起的时间复杂度就是 \(O(n^3)\)。
显然,我们可以首先优化第三个 \(n\):对于每一棵子树建立一棵权值线段树,维护每一个权值的出现次数。然后第三个 \(n\) 就可以转化为 \(log_2 n\) 了。此时的时间复杂度是 \(O(n^2 log_2n)\)。注意:我们此时需要用线段树合并来将两个子树的信息合并到当前节点上去。
这个时间复杂度不够优秀,所以我们继续考虑如何优化:我们可以借助线段树合并来计算逆序对个数。具体地,每当我们在线段树中需要 \(return\) 的时候,我们都判断一下当前区间为左子树区间还是右子树区间,即在线段树合并后,该区间全部为左子树还是右子树上的数。(线段树合并只有在两棵线段树中某一棵线段树在当前区间不存在任何节点时才会退出)。然后我们将当前区间的权值的个数和乘上另一棵子树的大于这一区间的权值个数,就可以得到交换或者不交换的逆序对个数(若当前区间是左子树区间,则得到的是交换后的逆序对个数,否则是不交换时的逆序对个数)。此时的时间复杂度就相当于在线段树合并的基础上再加上一个 \(log_2n\) 的时间复杂度,但是我们就不用将子树内的每一个数都枚举一遍,时间复杂度就是 \(O(nlog_2^2n)\)。
其实这个时候已经可以通过了,但是考虑到 \(loj\) 把时间限制开到了 \(160ms\),我们继续优化。注意到,我们在线段树合并的时候一定遍历过我们用 \(log_2n\) 时间复杂度去查询的区间,我们可以借助这一信息直接对逆序对个数进行更改。当我们递归到某一区间时,当前区间的左子区间里所有的数一定小于当前区间的右子区间里所有的数。此时这个区间对交换前的逆序对个数的贡献就是左子树在右子区间里权值的个数乘以右子树左子区间里权值的个数。对交换后的逆序对个数贡献同理。根据线段树合并的特性,我们一定会把所有的可以提供贡献的区间遍历到,且对于每个区间只需要 \(O(1)\) 的时间复杂度计算出其对两种逆序对个数的贡献。此时我们的时间复杂度就是完美的 \(O(n log_2n)\),可以通过此题(虽然在 \(loj\) 上还是需要卡卡常)。
4 代码:
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 4e5+10;
typedef long long ll;
int n, x, tot, cnt;
ll cnt1, cnt2, ans; // cnt1 : 交换后的最小逆序对个数,cnt2 : 不交换时的最小逆序对个数
struct node
{
int lc, rc;
int sum; // 权值出现次数
}t[N*40];
int rt[N];
int read()
{
int x = 0;
bool f = 0;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-') f = 1;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return f ? -x : x;
}
int build()
{
tot++;
t[tot].lc = t[tot].rc = t[tot].sum = 0;
return tot;
}
void pushup(int p) {t[p].sum = t[t[p].lc].sum + t[t[p].rc].sum;}
void modify(int p, int l, int r, int pos)
{
if (l == r)
{
t[p].sum++;
return ;
}
int mid = (l + r) >> 1;
if (pos <= mid)
{
if (!t[p].lc) t[p].lc = build();
modify(t[p].lc, l, mid, pos);
}
else
{
if (!t[p].rc) t[p].rc = build();
modify(t[p].rc, mid+1, r, pos);
}
pushup(p);
}
int merge(int p, int p2, int l, int r)
{
if (!p) return p2;
if (!p2) return p;
if (l == r)
{
t[p2].sum += t[p].sum;
return p2;
}
int mid = (l + r) >> 1;
cnt1 += (ll)t[t[p].lc].sum * (ll)t[t[p2].rc].sum;
cnt2 += (ll)t[t[p].rc].sum * (ll)t[t[p2].lc].sum;
t[p2].lc = merge(t[p].lc, t[p2].lc, l, mid);
t[p2].rc = merge(t[p].rc, t[p2].rc, mid+1, r);
pushup(p2);
return p2;
}
int dfs()
{
cnt++;
int x = cnt;
rt[x] = build();
int val = read();
if (!val)
{
int lc = dfs(), rc = dfs();
cnt1 = cnt2 = 0;
rt[x] = merge(rt[lc], rt[rc], 1, n);
ans += min(cnt1, cnt2);
}
else modify(rt[x], 1, n, val);
return x;
}
int main()
{
n = read();
dfs();
printf("%lld", ans);
return 0;
}