初学树分治的学习笔记。
分治本质上都是一样的,就是要把原问题分割成几个规模更小的问题分别求解然后再合并上去。而由于树它本身具有很强的递归特性(笼统的说就是树的部分还是树),于是就使得树上分治成为可能。
树分治一般遵循以下步骤:
第一步,把树中所有节点全部找出来,然后把它们近乎于暴力地统计一遍答案(此步骤一般是 \(O(N)/O(NlogN)\)),而且此步骤只统计所有经过根节点的路径;第二步就是找到当前这棵树的重心(这样可以使得分治下去的问题规模得到保障从而保证复杂度),然后递归处理即可。整体复杂度是\(O(NlogN)\)左右。
按照这个思路就可以写这道题了。
#include<cstdio>
#include<cstring>
#include<algorithm>
//#define zczc
#define ll long long
using namespace std;
const int N=40010;
const int maxn=1e9;
using namespace std;
inline void read(int &wh){
wh=0;int f=1;char w=getchar();
while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
wh*=f;return;
}
inline int max(int s1,int s2){
return s1<s2?s2:s1;
}
int m,want;
struct edge{
int t,v,next;
}e[N<<1];
int head[N],esum;
inline void add(int fr,int to,int val){
esum++;
e[esum].t=to;
e[esum].v=val;
e[esum].next=head[fr];
head[fr]=esum;
return;
}
ll ans;
bool vis[N];
//solve pairs
int cnt,dis[N];
void dfs(int wh,int fa,int ndis){
dis[++cnt]=ndis;
for(int i=head[wh],th;i;i=e[i].next){
th=e[i].t;
if(vis[th]||th==fa)continue;
dfs(th,wh,ndis+e[i].v);
}
}
int work(int wh,int val){
//init
cnt=0;
dfs(wh,0,val);
sort(dis+1,dis+cnt+1);
int an=0;
for(int i=cnt,j=0;i;i--){
//对于dis[i]找到有多少个符合条件的j与之对应
while(j==0||(j<=cnt&&dis[j]+dis[i]<=want))j++;j--;
an+=j;if(i<=j)an--;
}
return an/2;
}
//get new root
int hsize;//whole size
int amax;//ans maxn
int nroot,size[N];
void dfs2(int wh,int fa){
size[wh]=1;
int nmax=0;
for(int i=head[wh],th;i;i=e[i].next){
th=e[i].t;
if(vis[th]||th==fa)continue;
dfs2(th,wh);
size[wh]+=size[th];
nmax=max(nmax,size[th]);
}
nmax=max(nmax,hsize-size[wh]);
if(nmax<amax){
amax=nmax;
nroot=wh;
}
return;
}
void solve_root(int wh){
amax=maxn;
hsize=size[wh];
dfs2(wh,0);
return;
}
void solve(int wh){
//printf("now:%d %d\n",wh,ans);
ans+=work(wh,0);
vis[wh]=true;
for(int i=head[wh],th;i;i=e[i].next){
th=e[i].t;
if(vis[th])continue;
ans-=work(th,e[i].v);
solve_root(th);
solve(th);
}
return;
}
void init(){
memset(head,0,sizeof(head));
esum=0;
memset(vis,0,sizeof(vis));
ans=0;
return;
}
signed main(){
#ifdef zczc
freopen("in.txt","r",stdin);
#endif
read(m);
int s1,s2,s3;
for(int i=1;i<m;i++){
read(s1);read(s2);read(s3);
add(s1,s2,s3);add(s2,s1,s3);
}
read(want);
hsize=m;
solve(1);
printf("%lld\n",ans);
return 0;
}