树形枚举--搜索
题目描述:
给你一棵树,要在一条简单路径上选3个不同的点构成一个集合,问能构成多少个不同的集合。
解法:
枚举所有结点,假设某个结点有n棵子树,每棵子树的结点个数分别为s1,s2,````sn.那么在选中该结点后,剩下的两个结点从子树上选,考虑顺序,则有方案数ans = s1*(sum(si) - s1) + s2*(sum(si) - s2) + ``` + sn*(sum(si) - sn),化简得ans = sum(si) ^ 2 - sum(si ^2) .实际上,另外两个点选了(1,2)和(2,1)对于集合{a,b,c}而言是一样的,所以方案数应该为ans/2。
而一个结点的子节点的总和为n-1。只要求sum(si^2)即可,由于是从某一点开始搜的,我们把这个点想成了根结点,实际上对于除这个点以外的其他点,除了按搜索顺序认为的子树外,剩下的结点是它的另一棵子树。
代码实现:
ll ans=0;
int dfs(int x)//返回该结点的某棵子树的结点个数
{
vis[x] =1;
ll sq=0;//”子树”结点个数的平方和
int tot =0;//”子树”结点个数的总和
for(int i=0; i<g[x].size(); ++i)
{
if(!vis[g[x][i]])
{
int son = dfs(g[x][i]);//子结点的个数
tot += son;
sq += (ll)son*son;
}
}
sq += (ll)(n-tot-1)*(n-tot-1);
ans += ((ll)(n-1)*(n-1) - sq)/2;
return tot+1;//当自己作为子树时,还要加上自己,所以加1.
}
另一种解法:
直接考虑顺序计数,有方案数ans =s1*(s2+s3+````+sn)+s2*(s3+s4+```sn) +```` + s[n-1]*sn. 则有设sum[i]为前i项和,则有ans = s1*(n-1-sum[1]) + s2*(n-1-sum[2]) +````+s[n-1]*(n-1-sum[n-1]) + sn *0.
代码实现:
ll ans;
int dfs(int x)
{
vis[x] =1;
int tmp=0;
for(int i=0; i<g[x].size(); ++i)
{
if(!vis[g[x][i]])
{
int son = dfs(g[x][i]);//当前分支的儿子个数
tmp += son;//已经求出的儿子个数,相当于sum[i]
ans += (ll)(n-1-tmp)*son;
}
}
return tmp+1;
}
然后此题的正解就是C(n,3)- ans.C(n,3)表示从n个点中选3个点,C(n,3) = n*(n-1)*(n-2)/6
贴代码1:
#pragma comment(linker, "/STACK:16777216")
#include<cstdio>
#include<vector>
#define N 100010
using namespace std;
typedef long long int ll;
vector<int> g[N];
bool vis[N];
int n;
ll ans;
int dfs(int x)
{
vis[x] =;
ll sq=;//子节点个数的平方和
int tot =;//子节点个数之和
for(int i=; i<g[x].size(); ++i)
{
if(!vis[g[x][i]])
{
int son = dfs(g[x][i]);//子节点的个数
tot += son;
sq += (ll)son*son;
}
}
sq += (ll)(n-tot-)*(n-tot-);
ans += ((ll)(n-)*(n-) - sq)/;
return tot+;
}
int main()
{
freopen("1010.in","r",stdin);
while(~scanf("%d",&n))
{
for(int i=; i<=n; ++i)
{
g[i].clear();
vis[i] =;
}
for(int i=; i<n; ++i)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
ans =;
dfs();
ll sum = (ll)n*(n-)*(n-)/;
printf("%I64d\n",sum-ans);
}
return ;
}
贴代码2:
#pragma comment(linker, "/STACK:16777216")
#include<cstdio>
#include<vector>
#define N 100010
using namespace std;
typedef long long int ll;
vector<int> g[N];
bool vis[N];
int n;
ll ans;
int dfs(int x)
{
vis[x] =;
int tmp=;
for(int i=; i<g[x].size(); ++i)
{
if(!vis[g[x][i]])
{
int son = dfs(g[x][i]);//当前分支的儿子个数
tmp += son;//已经求出的儿子个数
ans += (ll)(n--tmp)*son;
}
}
return tmp+;
}
int main()
{
// freopen("1010.in","r",stdin);
while(~scanf("%d",&n))
{
for(int i=; i<=n; ++i)
{
g[i].clear();
vis[i] =;
}
for(int i=; i<n; ++i)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
ans =;
dfs();
ll sum = (ll)n*(n-)*(n-)/;
printf("%I64d\n",sum-ans);
}
return ;
}