基本概念
-
路径:在一棵树中,一个结点到另一个结点之间的通路,称为路径。下图中,从根结点到结点 a 之间的通路就是一条路径。
-
路径长度:在一条路径中,每经过一个结点,路径长度都要加 1 。例如在一棵树中,规定根结点所在层数为1层,那么从根结点到第 i 层结点的路径长度为 i - 1 。下图从根结点到结点 c 的路径长度为 3。
-
结点的权:给每一个结点赋予一个新的数值,被称为这个结点的权。下图中结点a的权是1。
-
结点的带权路径长度:指的是从根结点到该结点之间的路径长度与该结点的权的乘积。下图中结点 b 的带权路径长度为 b * 3 。
-
树的带权路径长度:树中所有叶子结点的带权路径长度之和。通常记作 “WPL” 。下图中所示的这颗树的带权路径长度为:
W P L = a ∗ 3 + b ∗ 3 + c ∗ 3 + d ∗ 3 + e ∗ 2 + f ∗ 2 = 24 WPL = a * 3 + b * 3 + c * 3 + d * 3 + e * 2 + f * 2 = 24 WPL=a∗3+b∗3+c∗3+d∗3+e∗2+f∗2=24
带权路径长度的计算也可以看作结点合并:
1、合并a和b为g、c和d为h,代价: a + b + c + d = 4 a+b+c+d = 4 a+b+c+d=4
2、合并g和h为i,代价: g + h = a + b + c + d = 4 g+h=a+b+c+d = 4 g+h=a+b+c+d=4
3、合并e和f为j,代价: e + f = 6 e + f = 6 e+f=6
4、合并i和j为k,代价: i + j = a + b + c + d + e + f = 10 i + j = a + b + c + d + e + f = 10 i+j=a+b+c+d+e+f=10
总共代价为24。 -
哈夫曼树:当用 n 个结点(都做叶子结点且都有各自的权值)试图构建一棵树时,如果构建的这棵树的带权路径长度最小,称这棵树为“最优二叉树 / 赫夫曼树 / 哈夫曼树”。
算法
如何构建huffman树呢?
只需要遵循一个原则:权重越大的结点离树根越近
具体做法:
1、先从所有结点中挑选两个最小的结点,合并成新的结点
2、去掉合并前的这两个结点,加入新的结点
3、重复1和2步直到所有结点合并
以上图为例:
- a=1、b=1、c=1、d=1、e=3、f=3,最小的结点是a和b。
合并a、b,新的结点g=2。 - c=1、d=1、e=3、f=3、g=2,最小的结点是c和d
合并c、d,新的结点h=2 - e=3、f=3、g=2、h=2,最小的结点是g和h
合并g、h,新的结点i=4 - e=3、f=3、i=4,最小的结点是e和f
合并e、f,新的结点j=6 - i=4、j=6,最小的结点是i和j
合并i、j,新的结点k=10,所有结点合并,停止。
每次求最小的两个结点可以用小根堆。
正确性
证明1:在huffman树中,权值最小的两个点深度一定最深,并且可以互为兄弟。
假设某个点权值是最小的,并且深度不是最深。
那么将它和最深的点交换,树的带权路径长度一定会变小。
同样以下图为例,假如交换了f和b,那么这两点的带权路径长度从
b
∗
2
+
f
∗
3
=
1
∗
2
+
3
∗
3
=
11
b*2+f*3=1*2+3*3=11
b∗2+f∗3=1∗2+3∗3=11变成了
b
∗
3
+
f
∗
2
=
1
∗
3
+
3
∗
2
=
9
b*3+f*2=1*3+3*2=9
b∗3+f∗2=1∗3+3∗2=9。
因此权值最小的点都会被交换到最底下(深度最深)。
显然abcd互相交换,深度不变,树的带权路径长度仍是最小。
由证明1可知,最优解一定要先合并最小的两个点,因为二者深度一定最深且在同一层。
证明2:每次从所有所有结点中挑选最小的两个结点能保证全局最优(整棵树树的带权路径长度最小)
合并完最小的两个结点后,新的结点和剩下的n-2个结点,总共有n-1个结点。
由证明1我们知道,不管这n-1个结点的合并方案是什么,n个结点合并的最优解一定是先合并最小的两个点。
因此我们只要求剩下n-1个结点合并的最小代价,再加上最小的两个结点合并的代价,就可以求出全局最优解。
而剩下n-1个结点可以重复上述过程,再找n-1个结点中的最小的两个结点合并。
代码实现 O(nlogn)
这题消耗的体力值就是Huffman树的带权路径长度。
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
int n, ans;
priority_queue<int, vector<int>, greater<int> > heap;
int main() {
scanf("%d", &n);
int t;
for (int i = 0; i < n; i ++ ) {
scanf("%d", &t);
heap.push(t);
}
while (heap.size() > 1) {
int a = heap.top();
heap.pop();
int b = heap.top();
heap.pop();
heap.push(a + b);
ans += a + b;
}
printf("%d\n", ans);
return 0;
}
桶排\基数排序+双队列 O(n)
题目链接
这题的n扩大到了
1
0
7
10^7
107,因此用O(nlogn)会超时。
因为每次挑的两个最小结点都越来越大,因此合并完的结点也是越来越大。
先将所有结点排序存在队列a中,从a挑最小的两个结点合并。
令开一个新队列b存合并完的结点
下次挑最小的两个结点时,因为a和b都是有序的,可以直接从a和b的队列的队头挑一个最小的结点,挑选两次。
这样就将挑选最小两个结点的时间降到O(1)
由于快排是nlogn的,因为这题数字范围很小( 1 ≤ a i ≤ 1 0 5 1≤a_i≤10^5 1≤ai≤105),要换成O(n)的桶排。如果数字范围较大,可以换成接近 O ( ( n + b ) l o g b n ) O((n+b)log_{b}n) O((n+b)logbn)的基数排序。
注意要用long long和手写读入,scanf会超时。
桶排序
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
int n;
long long ans;
int bucket[100005];
queue<long long> a, b;
inline int read(){
int ret = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0', ch = getchar();
return ret * f;
}
int main() {
scanf("%d", &n);
long long t, max_t = 0;
for (int i = 0; i < n; i ++ ) {
// scanf("%d", &t);
t = read();
bucket[t] ++ ;
max_t = max(max_t, t);
}
for (int i = 0; i <= max_t; i ++ ) {
while(bucket[i] -- ) a.push(i);
}
for (int i = 1; i < n; i ++ ) {
long long t1, t2;
if (b.empty() || (!a.empty() && a.front() <= b.front() )) {
t1 = a.front();
a.pop();
}
else {
t1 = b.front();
b.pop();
}
if (b.empty() || (!a.empty() && a.front() <= b.front() )) {
t2 = a.front();
a.pop();
}
else {
t2 = b.front();
b.pop();
}
b.push(t1 + t2);
ans += t1 + t2;
}
printf("%lld\n", ans);
return 0;
}
基数排序
#include <iostream>
#include <algorithm>
#include <queue>
#include <cstring>
using namespace std;
int n;
long long ans;
int x[10000005], y[10000005], cnt[260];
queue<long long> a, b;
inline int read() {
int ret = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0', ch = getchar();
return ret * f;
}
int main() {
scanf("%d", &n);
long long t, max_t = 0;
for (int i = 0; i < n; i ++ ) x[i] = read();
for (int i = 0; i <= 17; i += 8) {
memset(cnt, 0, sizeof cnt);
for (int j = 0; j < n; j ++ ) cnt[(x[j] >> i) & 255] ++ ;
for (int j = 1; j < 256; j ++ ) cnt[j] += cnt[j - 1];
for (int j = n - 1; j >= 0; j -- ) { // 从后往前扫a,同时cnt从高往低减,b从后往前放,等价于正着做
y[ -- cnt[(x[j] >> i) & 255]] = x[j];
}
for (int j = 0; j < n; j ++ ) x[j] = y[j];
}
for (int i = 0; i < n; i ++ ) a.push(x[i]);
for (int i = 1; i < n; i ++ ) {
long long t1, t2;
if (b.empty() || (!a.empty() && a.front() <= b.front() )) {
t1 = a.front();
a.pop();
}
else {
t1 = b.front();
b.pop();
}
if (b.empty() || (!a.empty() && a.front() <= b.front() )) {
t2 = a.front();
a.pop();
}
else {
t2 = b.front();
b.pop();
}
b.push(t1 + t2);
ans += t1 + t2;
}
printf("%lld\n", ans);
return 0;
}