题目大意:大概是这样的,一棵树n个点,每个点有点权val[i]和cost[i],给定一个m,对于每颗子树,计算出一个w值,w的计算方法为(val[i]*k),其中k为i子树下,最多能取出的使得cost的和小于等于m点的个数。
解题思路:假如,我们能将一颗子树的每个点,以cost为key,建成平衡树,那么计算答案想必还是比较简单的吧。但是我们不能给每颗子树建一棵平衡树,那么我们就从叶子节点开始计算,然后往上合并,合并两棵树时,将小的往大的里面一个个的插。
代码:
#include <stdio.h> #include <string.h> #include <algorithm> #include <map> #include <math.h> #include <queue> #include <vector> #include <string> #include <iostream> #include <stdlib.h> #include <time.h> #define lowbit(x) (x&(-x)) #define ll long long #define lson l , m , rt << 1 #define rson m + 1 , r , rt << 1 | 1 #define ls son[0][rt] #define rs son[1][rt] #define new_edge(a,b,c) edge[tot].t = b , edge[tot].v = c , edge[tot].next = head[a] , head[a] = tot ++ using namespace std; const int maxn = 111111 ; int son[2][maxn] , fa[maxn] , size[maxn] ; ll val[maxn] , sum[maxn] ; int pos[maxn] ; int new_node ( int _val , int rt ) { if ( !rt ) return 0 ; sum[rt] = val[rt] = _val ; size[rt] = 1 ; fa[rt] = son[0][rt] = son[1][rt] = 0 ; return rt ; } struct Edge { int t , next , v ; } edge[maxn<<1] ; int head[maxn] , tot ; ll m , num[maxn] ; void push_up ( int rt ) { size[rt] = size[ls] + size[rs] + 1 ; sum[rt] = val[rt] + sum[ls] + sum[rs] ; } void rot ( int rt , int c ) { int y = fa[rt] , z = fa[y] ; son[!c][y] = son[c][rt] , fa[son[c][rt]] = y ; son[c][rt] = y , fa[y] = rt ; fa[rt] = z ; son[y==son[1][z]][z] = rt ; push_up ( y ) ; } void splay ( int rt ) { while ( fa[rt] ) { int y = fa[rt] , z = fa[y] ; if ( !fa[y] ) rot ( rt , rt == son[0][y] ) ; else { int c = ( rt == son[0][y] ) , d = ( y == son[0][z] ) ; if ( c == d ) rot ( y , c ) , rot ( rt , c ) ; else rot ( rt , c ) , rot ( rt , d ) ; } } push_up ( rt ) ; } void insert ( int rt , int y ) { if ( val[rt] <= val[y] ) { if ( !rs ) { rs = y , fa[y] = rt ; push_up ( rt ) ; return ; } insert ( rs , y ) ; } else { if ( !ls ) { ls = y , fa[y] = rt ; push_up ( rt ) ; return ; } insert ( ls , y ) ; } push_up ( rt ) ; } void print ( int rt ) { if ( !rt ) return ; printf ( "rt = %d , fa = %d , sum = %I64d\n" , rt , fa[rt] , sum[rt] ) ; printf ( "ls = %d , rs = %d , val = %I64d , size = %d\n" , ls , rs , val[rt] , size[rt] ) ; print ( ls ) ; print ( rs ) ; } void join ( int& x , int y ) { if ( son[0][y] ) join ( x , son[0][y] ) ; if ( son[1][y] ) join ( x , son[1][y] ) ; new_node ( val[y] , y ) ; insert ( x , y ) ; splay ( y ) ; x = y ; } ll ans = 0 ; int cnt ( int rt , ll now , int k ) { if ( now + sum[ls] + val[rt] <= m ) { if ( !rs ) return k + size[rt] ; else return cnt ( rs , now + sum[ls] + val[rt] , k + size[ls] + 1 ) ; } else { if ( !ls ) return k ; else return cnt ( ls , now , k ) ; } } void dfs ( int u ) { int i ; int temp = u ; for ( i = head[u] ; i != -1 ; i = edge[i].next ) { int v = edge[i].t ; dfs ( v ) ; v = pos[v] ; if ( size[temp] > size[v] ) join ( temp , v ) ; else join ( v , temp ) , temp = v ; } int k = cnt ( temp , 0 , 0 ) ; ans = max ( ans , num[u] * k ) ; pos[u] = temp ; } int main() { int n , i , j , k ; while ( scanf ( "%d%d" , &n , &m ) != EOF ) { memset ( head , -1 , sizeof ( head ) ) ; ans = tot = 0 ; int rt ; for ( i = 1 ; i <= n ; i ++ ) { int a , b , c ; pos[i] = i ; scanf ( "%d%d%d" , &a , &b , &c ) ; if ( a == 0 ) rt = i ; new_edge ( a , i , 0 ) ; num[i] = c ; new_node ( b , i ) ; } dfs ( rt ) ; printf ( "%lld\n" , ans ) ; } return 0; }