好像用到一些高中数学知识......
满分做法:
case 0:已知a数组求b数组
因为是树状结构,设当前节点x 儿子to
我们从任意一点出发可求出b[root]来,之后我们可以通过寻找两两相连节点的关系来O(n)推出全部的b
我们发现x与y之间只有一条边的贡献不同,就是他们相连的边
(边的贡献即该边节点所在子树通过该点的a权值和)
那么我们就轻松搞掉了......
case 1:已知b求a
设sum[i]为以i为根的子树的a值和,all为总值。
我们首先可以发现b[x]-b[to]的差值可以用sum[to]表示
两者之间的差值其实就是all-sum[to]与sum[to]的差值
那么我们用起高中数学的类似等差数列的东西????
从树上遍历一番将所有的边的x与to的上述式子加和
然后注意记录每个b数组的系数,显然如果是叶子节点为-1,其他也要记录
我们发现化简后
ss[1]*b[1]+........=2*(sum[2]+......sum[n])-all*(n-1)
(ss是系数,all是a的总值)
其中(sum[2]......)总值是2*b[1],之后DFS统计就行了
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<string> 5 #include<algorithm> 6 #include<cmath> 7 #include<stack> 8 #include<map> 9 #include<queue> 10 #define ps push_back 11 #define MAXN 210001 12 #define ll long long 13 using namespace std; 14 ll a[MAXN],b[MAXN]; 15 ll T,n; 16 ll head[MAXN],tot; 17 struct node{ll to,n;}e[2*MAXN]; 18 void add(ll u,ll v) 19 { 20 e[++tot].to=v;e[tot].n=head[u];head[u]=tot; 21 } 22 bool vis[MAXN];ll sum[MAXN]; 23 ll fa[MAXN];ll all; 24 void DFS(ll x) 25 { 26 vis[x]=1; 27 all+=a[x]; 28 sum[x]=a[x]; 29 for(ll i=head[x];i;i=e[i].n) 30 { 31 ll to=e[i].to; 32 if(vis[to]==1)continue; 33 fa[to]=x; 34 DFS(to); 35 sum[x]+=sum[to]; 36 b[1]+=sum[to]; 37 } 38 } 39 void DFS_findans(ll x) 40 { 41 vis[x]=1; 42 for(ll i=head[x];i;i=e[i].n) 43 { 44 ll to=e[i].to; 45 if(vis[to]==1)continue; 46 b[to]=b[x]-2*sum[to]+all; 47 DFS_findans(to); 48 } 49 } 50 ll orz; 51 void work_a() 52 { 53 memset(sum,0,sizeof(sum)); 54 all=0; 55 memset(vis,0,sizeof(vis)); 56 DFS(1); 57 memset(vis,0,sizeof(vis)); 58 DFS_findans(1); 59 for(ll i=1;i<=n;++i) 60 { 61 printf("%lld ",b[i]); 62 } 63 cout<<endl; 64 } 65 ll chu[MAXN]; 66 ll ss[MAXN]; 67 void DFS_B(ll x) 68 { 69 vis[x]=1; 70 //printf("x=%lld\n",x); 71 for(ll i=head[x];i;i=e[i].n) 72 { 73 ll to=e[i].to; 74 if(vis[to]==1)continue; 75 fa[to]=x; 76 ss[x]++; 77 ss[to]--; 78 DFS_B(to); 79 } 80 } 81 void DFS_find(ll x) 82 { 83 vis[x]=1;a[x]=sum[x]; 84 for(ll i=head[x];i;i=e[i].n) 85 { 86 ll to=e[i].to; 87 if(vis[to])continue; 88 sum[to]=(b[x]-b[to]+all)/2; 89 DFS_find(to); 90 a[x]-=sum[to]; 91 } 92 } 93 void work_b() 94 { 95 memset(a,0,sizeof(a)); 96 memset(sum,0,sizeof(sum)); 97 memset(fa,0,sizeof(fa)); 98 all=0; 99 memset(vis,0,sizeof(vis)); 100 memset(ss,0,sizeof(ss)); 101 DFS_B(1); 102 103 ll pss=0; 104 for(ll i=1;i<=n;++i) 105 { 106 //printf("ss%lld b%lld\n",ss[i],b[i]); 107 pss+=ss[i]*b[i]; 108 } 109 all=(2*b[1]-pss)/(n-1); 110 sum[1]=all; 111 //printf("all=%lld\n",all); 112 memset(vis,0,sizeof(vis)); 113 114 DFS_find(1); 115 116 for(ll i=1;i<=n;++i) 117 { 118 printf("%lld ",a[i]); 119 } 120 cout<<endl; 121 } 122 int main() 123 { 124 scanf("%lld",&T); 125 while(T--) 126 { 127 memset(a,0,sizeof(a)); 128 memset(b,0,sizeof(b)); 129 memset(head,0,sizeof(head)); 130 memset(vis,0,sizeof(vis)); 131 memset(chu,0,sizeof(chu)); 132 tot=0; 133 scanf("%lld",&n); 134 for(ll i=1;i<=n-1;++i) 135 { 136 ll x,y; 137 scanf("%lld%lld",&x,&y); 138 add(x,y);add(y,x); 139 chu[x]++;chu[y]++; 140 } 141 scanf("%lld",&orz); 142 if(orz==0) 143 { 144 for(ll i=1;i<=n;++i) 145 { 146 scanf("%lld",&a[i]); 147 } 148 work_a(); 149 } 150 else 151 { 152 for(ll i=1;i<=n;++i) 153 { 154 scanf("%lld",&b[i]); 155 } 156 work_b(); 157 } 158 } 159 }View Code