有几档暴力不会写,巨丢人
\(m=2\) 的话两个人之间的距离会覆盖整棵树上所有可能的路径,所以就是求所有树上路径长度的总和
成链且 \(m\) 为奇数的话,集中点肯定是中位数那个点
考场上想偏了,只会用这个性质求一些给定的人应该集中在哪个点
但实际上可以枚举中位数这个点,求出一共有多少匹配的方案
然后正解
点并不好考虑,所以考虑边
如果一条边两边人数不等那数量较少的那些人肯定都得经过这条边
令 \(s\) 为这条边一边的人数
于是一条边的贡献为 \(\sum\limits_{i=1}^{m-1}\binom{s}{i}\binom{n-s}{m-i}min(i, m-i)\)
直接求复杂度会炸,于是转化一下,先考虑弄掉那个min
等价于 \(\sum\limits_{i=1}^{\frac{m-1}{2}}\binom{s}{i}\binom{n-s}{m-i}i + \binom{n-s}{i}\binom{s}{m-i}i\)
令 \(k=\frac{m-1}{2}\)
观察这个式子 \(\sum\limits_{i=1}^{k}\binom{s}{i}\binom{n-s}{m-i}i\),试着把外面的 \(i\) 去掉
于是令 \(G(s)=\sum\limits_{i=1}^{k}\binom{s-1}{i-1}\binom{n-s}{m-i}\),则原式等于 \(s*G(s)\)
- 一个全是组合数数的式子想转化或者 \(O(n)\) 递推的话貌似可以根据式子给它一个组合意义
考虑组合意义,即为在 \(n-1\) 个物品里选 \(m-1\) 个,要求前 \(s-1\) 中最多能选 \(k-1\) 个的方案数
推到 \(G(s+1)\) 的话会少了前 \(s-1\) 中选了 \(k-1\) 个,且第 \(s+1\) 个也被选了的情况
于是可以 \(O(n)\) 递推
与这个类似的另一部分实际上就是 \(G(n-s)\),不用另算了
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
// #define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48), c=getchar();
return ans*f;
}
int n, m;
int head[N], size;
const ll mod=1000000007;
struct edge{int to, next;}e[N<<1];
inline void add(int s, int t) {e[++size].to=t; e[size].next=head[s]; head[s]=size;}
namespace force{
int vis, dp[30], siz[30], minn;
ll ans;
void dfs1(int u) {
siz[u]=(vis&(1<<(u-1)))?1:0;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs1(v);
siz[u]+=siz[v];
dp[u]+=dp[v]+siz[v];
}
}
void dfs2(int u, int sum) {
// cout<<"dfs2: "<<u<<' '<<sum<<endl;
minn=min(minn, sum+dp[u]);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs2(v, sum+dp[u]-dp[v]-siz[v]+(m-siz[v]));
}
}
void solve() {
memset(head, -1, sizeof(head));
for (int i=2; i<=n; ++i) add(read(), i);
int lim=1<<n;
for (int s=1,s2,cnt; s<lim; ++s) {
s2=s; cnt=0;
do {++cnt; s2&=s2-1;} while (s2) ;
if (cnt!=m) goto jump;
vis=s;
// cout<<"s: "<<bitset<5>(s)<<endl;
memset(dp, 0, sizeof(dp));
// memset(siz, 0, sizeof(siz));
dfs1(1); minn=INF; dfs2(1, 0);
// cout<<"siz: "; for (int i=1; i<=n; ++i) cout<<siz[i]<<' '; cout<<endl;
// cout<<"dp: "; for (int i=1; i<=n; ++i) cout<<dp[i]<<' '; cout<<endl;
// cout<<"minn: "<<minn<<endl;
ans=(ans+minn)%mod;
jump: ;
}
printf("%lld\n", ans);
exit(0);
}
}
namespace task1{
int siz[N];
ll fac[N], inv[N], ans;
inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;}
void dfs(int u) {
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
siz[u]+=siz[v];
}
}
void solve() {
memset(head, -1, sizeof(head));
for (int i=2; i<=n; ++i) add(read(), i);
dfs(1);
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
for (int i=1; i<n; ++i) {
for (int j=1; j<m; ++j) {
int s=siz[e[i].to];
ans=(ans+C(s, j)*C(n-s, m-j)%mod*min(j, m-j)%mod)%mod;
}
}
printf("%lld\n", ans);
exit(0);
}
}
namespace task{
int siz[N];
ll fac[N], inv[N], ans, G[N], H[N];
inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;}
inline ll H2(int s) {
int k=(m-1)/2;
ll ans=0;
for (int i=1; i<=k; ++i) ans=(ans+C(n-s-1, i-1)*C(s, m-i)%mod)%mod;
return ans;
}
inline ll G2(int s) {
int k=(m-1)/2;
ll ans=0;
for (int j=1; j<=k; ++j) ans=(ans+C(s-1, j-1)*C(n-s, m-j)%mod)%mod;
return ans;
}
void dfs(int u) {
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
siz[u]+=siz[v];
}
}
void solve() {
memset(head, -1, sizeof(head));
for (int i=2; i<=n; ++i) add(read(), i);
dfs(1);
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod;
G[1]=G2(1);
int k=(m-1)/2;
for (int i=1; i<n; ++i) G[i+1]=(G[i]-C(i-1, k-1)*C(n-i-1, m-k-1)%mod)%mod;
#if 0
cout<<"G: "; for (int i=1; i<=n; ++i) cout<<(G[i]+mod)%mod<<' '; cout<<endl;
for (int i=1; i<=n; ++i) {
ll t=0;
for (int j=1; j<=k; ++j)
t=(t+C(i-1, j-1)*C(n-i, m-j)%mod)%mod;
cout<<t<<' ';
} cout<<endl;
#endif
#if 1
// H[n-1]=H2(n-1);
// for (int i=n-1; i; --i) H[i-1]=(H[i]-C(n-i-1, k-1)*(i-1, m-k-1)%mod)%mod;
// for (int i=1; i<=n; ++i) H[i+1]=(H[i]+C(n-i-1, k-1)*(i-1, m-k-1)%mod)%mod;
// cout<<"H2: "; for (int i=1; i<=n; ++i) cout<<H2(i)<<' '; cout<<endl;
// cout<<"H: "; for (int i=1; i<=n; ++i) cout<<H[i]<<' '; cout<<endl;
#endif
for (int i=1; i<n; ++i) {
int s=siz[e[i].to];
// cout<<"H: "<<H2(s)<<endl;
ans=(ans+s*G[s]%mod+(n-s)*G[n-s]%mod+((m&1)?0:(C(s, m/2)*C(n-s, m/2)%mod*(m/2)%mod)))%mod;
}
printf("%lld\n", (ans%mod+mod)%mod);
exit(0);
}
}
signed main()
{
freopen("meeting.in", "r", stdin);
freopen("meeting.out", "w", stdout);
n=read(); m=read();
// force::solve();
task::solve();
return 0;
}