比较巧妙的构造题,但是重点实际上在bitset优化和输出方案上。
题意大概是:
定义节点的分值为子树中叶子节点的个数。现在给定所有节点(只有中间节点,不包含叶子节点)的分值,现在要求构造一棵森林,当用一个虚根把所有森林连起来之后,虚根的权值为$s$。
要求输出方案,每个节点输出他的直接儿子和除开这些儿子它挂了几个叶子。
其实我们只需要考虑最上层的一些,最上层的一些加起来刚好等于$s$,就说明可以组成$s$,否则的话肯定不行。
然后权值最大的那个节点一定是上层节点,因为它不可能作为任何其他树的子节点。
然后我们发现其实问题就解决了,因为一个树的分值只跟它最上面的节点分值有关,下面不管怎么连都没事,所以我们只需要把剩下的排个序,一个套一个,全部套在最大的那个节点底下。
然后重点来了:怎么判断是否存在一系列点加起来刚好等于$s$。
(背包居然是NP完全问题)
一个简单的背包思路$f_{i,j}$表示前$i$个物品,能否组成$j$的体积。
转移就是$f_{i, j} = f_{i - 1,j} || f_{i - 1, j - a[i]}$。
时间复杂度为$O(ns)$,爆炸了。
然后发现这个$f$数组为$01$数组,可以用bitset优化,可以除掉$32$,成功通过。
不过这样仅仅是判断,如何输出方案呢。
以下来自于另一博客的方法。
我们可以分段保存中间结果。
每隔10个,保存以下bitset数组,这样我们就间隔10保存了结果。
倒着推回去,假如$f_{i,j}$可行,并且用到了$a[i]$,就必存在一个$f_{i - 1,j - a[i]}$可行,我们从10步之前的那个结果再往后推一下,就能推到$f_{i - 1}$的结果了,就能判断是否可行。
有了方案之后还原一下就行了。
#include <bits/stdc++.h> #define Mid ((l + r) / 2) #define lson (rt << 1) #define rson (rt << 1 | 1) using namespace std; int read() { char c; int num, f = 1; while(c = getchar(),!isdigit(c)) if(c == ‘-‘) f = -1; num = c - ‘0‘; while(c = getchar(), isdigit(c)) num = num * 10 + c - ‘0‘; return f * num; } const int N = 7e4 + 1; struct node { int id, val, in, nxt; } a[N]; int n, s, nxt[N]; bitset<N> tmpx[N / 10]; bitset<N> tmp, tmp2; vector<node> b; int ok(int x, int y) { if(y < 0) return 0; tmp2 = tmpx[x / 10]; for(int i = x / 10 * 10 + 1; i <= x; i++) { tmp2 |= (tmp2 << a[i].val); } return tmp2[y]; } int cmp(node a, node b) {return a.val < b.val;} int cmp2(node a, node b) {return a.id < b.id;} signed main() { n = read(); s = read(); for(int i = 1; i <= n; i++) a[i].val = read(); for(int i = 1; i <= n; i++) a[i].id = i; sort(a + 1, a + 1 + n, cmp); tmp[0] = 1; tmpx[0] = tmp; for(int i = 1; i < n; i++) { tmp |= (tmp << a[i].val); if(i % 10 == 0) tmpx[i / 10] = tmp; } int tar = s - a[n].val; memset(nxt, -1, sizeof(nxt)); if(ok(n - 1, tar)) { int now = s - a[n].val; for(int i = n - 1; i; i--) { if(now >= a[i].val && ok(i - 1, now - a[i].val)) { now -= a[i].val; a[i].in = 1; } } int pri = n; for(int i = n - 1; i; i--) if(!a[i].in) { a[pri].nxt = a[i].id; pri = i; } sort(a + 1, a + 1 + n, cmp2); for(int i = 1; i <= n; i++) { if(a[i].nxt) { printf("%d 1 %d\n", a[i].val - a[a[i].nxt].val, a[i].nxt); } else { printf("%d 0\n", a[i].val); } } } else { printf("-1"); } return 0; }