题意
4908 Race 0x49「数据结构进阶」练习
描述
给定一棵 N 个节点的树,每条边带有一个权值。
求一条简单路径,路径上各条边的权值和等于K,且路径包含的边的数量最少。
输入格式
第一行两个整数 N, K。
第2~N行每行三个整数x,y,z,表示一条无向边的两个端点x,y和权值z,点的编号从0开始。
输出格式
一个整数,表示最少边数量。如果不存在满足要求的路径,输出-1。
样例输入
4 3 0 1 1 1 2 2 1 3 4
样例输出
2
数据范围与约定
- N <= 200000, K <= 1000000
来源
IOI2011
分析
训练指南的配套代码有问题,双指针绝对是错的。学习了hzwer的做法。
开一个100W的数组t,t[i]表示权值为i的路径最少边数
找到重心分成若干子树后, 得出一棵子树的所有点到根的权值和x,到根c条边,用t[k-x]+c更新答案,全部查询完后
然后再用所有c更新t[x]
这样可以保证不出现点分治中的不合法情况
把一棵树的所有子树搞完后再遍历所有子树恢复T数组,如果用memset应该会比较慢
时间复杂度\(O(n \log^2 n)\)
代码
网上好多代码,点分治递归的时候没有重新统计以重心为根的整棵树每个节点的size,按道理找出来的重心是错的,复杂度没有保证。
但是出题人根本就不会卡这种东西,所以不用写了,反而常数小。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define co const
template<class T>il T read(){
rg T data=0,w=1;rg char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') w=-1;ch=getchar();}
while(isdigit(ch)) data=data*10+ch-'0',ch=getchar();
return data*w;
}
template<class T>il T read(rg T&x) {return x=read<T>();}
typedef long long ll;
using namespace std;
typedef pair<int,int> pii;
co int N=2e5+1;
int n,k,ans,sum,root,max_size;
vector<pii> e[N];
int t[1000001],s[N],d[N],c[N],v[N];
void getroot(int x){
v[x]=1,s[x]=1;
int max_part=0;
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
getroot(y);
s[x]+=s[y],max_part=max(max_part,s[y]);
}
max_part=max(max_part,sum-s[x]);
if(max_part<max_size) root=x,max_size=max_part;
v[x]=0;
}
void cal(int x){
v[x]=1;
if(d[x]<=k) ans=min(ans,c[x]+t[k-d[x]]);
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
c[y]=c[x]+1,d[y]=d[x]+e[x][i].second;
cal(y);
}
v[x]=0;
}
void add(int x,int flag){
v[x]=1;
if(d[x]<=k){
if(flag) t[d[x]]=min(t[d[x]],c[x]);
else t[d[x]]=n;
}
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
add(y,flag);
}
v[x]=0;
}
void work(int x){
v[x]=1,t[0]=0;
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
c[y]=1,d[y]=e[x][i].second;
cal(y),add(y,1);
}
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
add(y,0);
}
for(int i=0,y;i<e[x].size();++i){
if(v[y=e[x][i].first]) continue;
sum=max_size=s[y];
getroot(y),work(root);
}
}
int main(){
// freopen(".in","r",stdin),freopen(".out","w",stdout);
read(n),read(k);
fill(t+1,t+k+1,n);
for(int i=1,u,v,w;i<n;++i){
read(u),read(v),read(w);
e[++u].push_back(pii(++v,w)),e[v].push_back(pii(u,w));
}
ans=sum=max_size=n;
getroot(1),work(root);
if(ans==n) puts("-1");
else printf("%d\n",ans);
return 0;
}