https://www.luogu.com.cn/problem/P4178
思路:
对u遍历每一棵子树,计算出dis,并询问前面子树有多少个点深度小于等于k−d[i] ,查询有多少值<=k-d[i]的,树状数组就好了。询问结束后把这个子树的答案累加进数组。遍历完所有子树后清空当前u的所有子树答案。注意是Dfs清空。
点分治后复杂度O(nlognlogn)
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
#define lowbit(x) x&(-x)
using namespace std;
const int maxn=4e4+1000;
typedef int LL;
inline LL read(){LL x=0,f=1;char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;}
LL n,k;
struct Edge{LL to,w;};
vector<Edge>g[maxn];
LL tree[maxn],d[maxn];
LL rt,siz[maxn],son[maxn],sum;
bool vis[maxn];
void add(LL x,LL d){ while(x<maxn){tree[x]+=d;x+=lowbit(x);}}
LL getsum(LL x){LL sum=0;while(x>0){sum+=tree[x];x-=lowbit(x);}return sum;}
void getroot(LL u,LL fa){
siz[u]=1;son[u]=0;
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i].to;
if(vis[v]||v==fa) continue;
getroot(v,u);
siz[u]+=siz[v];
if(siz[v]>son[u]) son[u]=siz[v];
}
son[u]=max(son[u],sum-siz[u]);
if((son[u]<<1)<=sum) rt=u;
}
void dfs_res(LL u,LL fa,LL &res){
if(d[u]<=k) res+=getsum(k-d[u])+1;///+1是本身的链
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i].to;LL cost=g[u][i].w;
if(vis[v]||v==fa) continue;
d[v]=d[u]+cost;
dfs_res(v,u,res);
}
}
void dfs_update(LL u,LL fa,LL val){
if(d[u]<=k&&d[u]) add(d[u],val);
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i].to;
if(vis[v]||v==fa) continue;
dfs_update(v,u,val);
}
}
LL solve(LL u,LL fa){
d[u]=0;
LL res=0;
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i].to;LL w=g[u][i].w;
if(vis[v]) continue;
d[v]=w;
dfs_res(v,u,res);
dfs_update(v,u,1);///添加子树信息
}
dfs_update(u,fa,-1);///清空树状数组
return res;
}
LL divide(LL u,LL fa){
LL ans=0;
vis[u]=true;
ans+=solve(u,fa);///先处理以u为根的树
for(LL i=0;i<g[u].size();i++){
LL v=g[u][i].to;LL w=g[u][i].w;
if(vis[v]) continue;
rt=0;///重心重新置为0
son[rt]=sum=siz[v];
getroot(v,0);
getroot(rt,0);
ans+=divide(rt,u);///分治下去
}
return ans;
}
int main(void){
n=read();
for(LL i=1;i<=n-1;i++){
LL u,v,w;u=read();v=read();w=read();///w++;///防止树状数组统计边权为0
g[u].push_back({v,w});
g[v].push_back({u,w});
}
k=read();///k++;
son[0]=sum=n;
getroot(1,0);
getroot(rt,0);
LL ans=0;
ans=divide(rt,0);
printf("%d\n",ans);
return 0;
}