【BZOJ3522&BZOJ4543】Hotel加强版(长链剖分,树形DP)

题意:求一颗树上三点距离两两相等的三元组对数

n<=1e5

思路:From https://blog.bill.moe/bzoj4543-hotel/

f[i][j]表示以i为根的子树中距离i为j的点的个数

g[i][j]表示以i为根的子树中两点距离他们的lca为d,lca距离i为d-j的两点对数

g[i][j]找到一个子树外的f[i][j]就对答案有贡献

朴素的方程为:设v为u的一个儿子

ans+=f[u][j]*g[v][j+1]+g[u][j]*f[y][j-1]

g[u][j+1]+=f[u][j+1]*f[v][j]

g[u][j-1]+=g[v][j]

f[u][j+1]+=f[v][j]

显然f[i][j]只和深度有关,且f[u]的[1,len[u]]这一段是所有f[v]的[0,len[u]-1]右移一位之和

为了防止同一个子树中的信息算多了,先算ans部分再执行后面三步更新

指针的写法我完全是抄的

 #include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> Pll;
typedef vector<int> VI;
typedef vector<PII> VII;
typedef pair<ll,int>P;
#define N 100010
#define M 200010
#define fi first
#define se second
#define MP make_pair
#define pi acos(-1)
#define mem(a,b) memset(a,b,sizeof(a))
#define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++)
#define per(i,a,b) for(int i=(int)a;i>=(int)b;i--)
#define lowbit(x) x&(-x)
#define Rand (rand()*(1<<16)+rand())
#define id(x) ((x)<=B?(x):m-n/(x)+1)
#define ls p<<1
#define rs p<<1|1 const ll MOD=1e9+,inv2=(MOD+)/;
double eps=1e-;
int INF=<<;
ll inf=5e13;
int dx[]={-,,,};
int dy[]={,,-,}; int head[M],vet[M],nxt[M],tot;
int len[N],son[N];
ll tmp[N*],*f[N],*g[N],*now=tmp,ans; int read()
{
int v=,f=;
char c=getchar();
while(c<||<c) {if(c=='-') f=-; c=getchar();}
while(<=c&&c<=) v=(v<<)+v+v+c-,c=getchar();
return v*f;
} void add(int a,int b)
{
nxt[++tot]=head[a];
vet[tot]=b;
head[a]=tot;
} void dfs(int u,int fa,int d)
{
len[u]=;
int e=head[u];
while(e)
{
int v=vet[e];
if(v!=fa)
{
dfs(v,u,d+);
if(len[v]>len[son[u]])
{
son[u]=v;
len[u]=len[v]+;
}
}
e=nxt[e];
}
} void solve(int u,int fa)
{
if(son[u])
{
f[son[u]]=f[u]+;
g[son[u]]=g[u]-;
solve(son[u],u);
}
f[u][]=;
ans+=g[u][];
int e=head[u];
while(e)
{
int v=vet[e];
if(v!=fa&&v!=son[u])
{
f[v]=now;
now+=(len[v]<<)+;
g[v]=now;
now+=(len[v]<<)+;
solve(v,u);
rep(j,,len[v])
{
if(j) ans+=f[u][j-]*g[v][j];
ans+=g[u][j+]*f[v][j];
}
rep(j,,len[v])
{
g[u][j+]+=f[u][j+]*f[v][j];
if(j) g[u][j-]+=g[v][j];
f[u][j+]+=f[v][j];
}
}
e=nxt[e];
}
}
int main()
{
int n=read();
tot=;
rep(i,,n-)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
len[]=-;
dfs(,,);
ans=;
f[]=now,now+=(len[]<<)+,g[]=now,now+=(len[]<<)+;
solve(,);
printf("%lld\n",ans);
return ;
}
上一篇:[LINUX] 查看连接数和IO负载


下一篇:React Native 环境配置