链接:https://codeforces.com/contest/161/problem/D
题意:给一个树,求距离恰好为$k$的点对是多少
题解:对于一个树,距离为$k$的点对要么经过根节点,要么跨过子树的根节点,于是考虑树分治
用类似poj1741的想法,可以推出:
对于任意一棵子树,其根节点记为$C$,其子树中:
记距离$C$距离之和为$k$的点对数量$S_{c}$
记$C$儿子节点$C_1...C_n$的子树中,距离$C_i$距离为$k-2$的点对数量为$S'_{c_i}$
其符合条件的点对数量即为$S_{c}-\sum_1^n S'_{c_i}$
(网上这题,主流的树分治写法好像不是这个...有些看不懂啊....)
树上点分治参考我之前的题解:https://www.cnblogs.com/nervendnig/p/10106333.html
速度还是很可以的
相比dp的话,dp收到$K$大小的限制,如果$K$的大小和N同级别,就很难朴素的DP了,可能就要考虑树上倍增DP(实际上好像不能倍增)
而分治显然并不受限制
具体参见代码:
#include <bits/stdc++.h>
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mp make_pair
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
#define IO ios::sync_with_stdio(false)
#define rep(ii,a,b) for(int ii=a;ii<=b;++ii)
#define per(ii,a,b) for(int ii=b;ii>=a;--ii)
#define forn(x,i) for(int i=head[x];i;i=e[i].next)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#define inline inline __attribute__( \
(always_inline, __gnu_inline__, __artificial__)) \
__attribute__((optimize("Ofast"))) __attribute__((target("sse"))) __attribute__((target("sse2"))) __attribute__((target("mmx")))
using namespace std;
#define tpyeinput int
char nc() {static char buf[1000000],*p1=buf,*p2=buf;return p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++;}
void read(tpyeinput &sum) {register char ch=nc();int flag=1;sum=0;while(ch<'0'||ch>'9') {if(ch=='-') flag=-1;ch=nc();}while(ch>='0'&&ch<='9') sum=(sum<<3)+(sum<<1)+(ch-48),ch=nc();sum*=flag;}
void read(tpyeinput &num1,tpyeinput &num2) {read(num1);read(num2);}
const int maxn=1e5+10,maxm=2e5+10;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
const double PI=acos(-1.0);
//head
int casn,n,m,k,mid,allnode;
struct node {int to,next;}e[maxm];int head[maxn],nume;
void add(int a,int b){e[++nume]=(node){b,head[a]};head[a]=nume;}
int sz[maxn],maxt,deep[maxn],vis[maxn],cnt;
ll ans;
void getc(int now,int pre){
sz[now]=1;
for(int i=head[now];i;i=e[i].next){
if(e[i].to==pre||vis[e[i].to])continue;
getc(e[i].to,now);
sz[now]+=sz[e[i].to];
}
int tmp=max(sz[now]-1,allnode-sz[now]);
if(maxt>tmp) maxt=tmp,mid=now;
}
void dfs(int now,int pre,int len,int dis){
deep[++cnt]=dis;
if(dis>=len)return;
for(int i=head[now];i;i=e[i].next){
if(e[i].to==pre||vis[e[i].to])continue;
dfs(e[i].to,now,len,dis+1);
}
}
ll cal(int rt,int pre,int len){
if(len<=0) return len==0;
cnt=0;
dfs(rt,pre,len,0);
ll res=0;
int num[507]{};
rep(i,1,cnt) num[deep[i]]++;
rep(i,1,cnt) res+=num[len-deep[i]];
return res;
}
void dc(int rt){
vis[rt]=1;
ans+=cal(rt,0,k);
for(int i=head[rt];i;i=e[i].next){
if(vis[e[i].to]) continue;
ans-=cal(e[i].to,rt,k-2);
allnode=sz[e[i].to],maxt=n;
getc(e[i].to,rt);dc(mid);
}
}
int main() {
//#define test
#ifdef test
auto _start = chrono::high_resolution_clock::now();
freopen("in.txt","r",stdin);freopen("out.txt","w",stdout);
#endif
read(n,k);
int a,b;
rep(i,1,n-1){
read(a,b);
add(a,b);add(b,a);
}
allnode=n;
maxt=INF;
getc(1,0);
dc(mid);
printf("%lld",ans/2);
#ifdef test
auto _end = chrono::high_resolution_clock::now();
cerr << "elapsed time: " << chrono::duration<double, milli>(_end - _start).count() << " ms\n";
fclose(stdin);fclose(stdout);system("out.txt");
#endif
return 0;
}