[dsu on tree] 2020CCPC长春F Strange Memory

首先考虑枚举lca的做法,对于每一个lca枚举其子树中所有节点,时间复杂度$O(n^2)$显然过不了

再思考发现这是一个针对子树的询问操作,考虑dsu on tree来统计答案

开一个新数组vec[x],其中x为权值,记录了所有权值为x的编号

那么只需要每次计算一颗新子树时,先累加答案,再更新vec数组(如果同时进行则可能会出现同一个子树间的节点被统计,然后他们的lca并不是我们枚举的lca)

但是这样时间复杂度反而变成了$O(n^2 logn)$

怎么办呢?考虑拆位。由于异或时每一位之间的运算是互相独立的(不存在进位借位等操作),所以我们对于每一位单独统计,就不用枚举编号啦(或者说枚举vec数组)

我们记数组cnt[x][0/1],表示在编号第bit位为0/1时,权值为x的子树中的节点数量

那么$ans = \sum $cnt[a[x]^a[lca]][(x>>bit)&1^1] 

 

这样的时间复杂度是$O(n \ logn \ logm)$的

我的统计方法是在主函数中枚举bit,每次进行一次dsu on tree

看了一眼某乎,貌似我这样的写法可能会T,更快的做法是扩充cnt,记录cnt[x][bit][0/1],在每次dsu的时候把全部算完...

#include<bits/stdc++.h>
using namespace std;
inline int read(){
  int ans=0,f=1;char chr=getchar();
  while(!isdigit(chr)){if(chr=='-')f=-1;chr=getchar();}
  while(isdigit(chr))ans=(ans<<3)+(ans<<1)+chr-48,chr=getchar();
  return ans*f;
}const int M = 2e5+5;
int n,m,a[M],cnt[(1<<20)+10000][2],tot,head[M],nxt[M],ver[M],sz[M],son[M],T,bit,b[M],ans,id[M],maxn;
inline void add(int x,int y){ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;}
void dfs(int x,int fa){
  sz[x]=1;
  for(int i=head[x];i;i=nxt[i])
    if(ver[i]!=fa)
      dfs(ver[i],x),
      sz[x]+=sz[ver[i]],
      (sz[ver[i]]>sz[son[x]])&&(son[x]=ver[i]);
}
inline void Clear(){
  for(register int i=1;i<=T;++i)
    --cnt[b[i]][(id[i]>>bit)&1];
  T=0;
}
inline void Insert(int x,int lca){
  b[++T]=a[x],id[T]=x;
  if(x!=lca)ans+=cnt[a[x]^a[lca]][((x>>bit)&1)^1];
}
inline void Re_Add(int x,int fa){
  ++cnt[a[x]][(x>>bit)&1];
  for(register int i=head[x];i;i=nxt[i])
    if(ver[i]!=fa)Re_Add(ver[i],x);
}
void Add_ans(int x,int fa,int lca){
  Insert(x,lca);
  for(register int i=head[x];i;i=nxt[i])
    if(ver[i]!=fa)Add_ans(ver[i],x,lca);
}
void dsu(int x,int fa){
  for(int i=head[x];i;i=nxt[i])
    if(ver[i]!=son[x]&&ver[i]!=fa)dsu(ver[i],x),Clear();
  if(son[x])dsu(son[x],x);
  ++cnt[b[++T]=a[x]][(x>>bit)&1],id[T]=x;
  for(int i=head[x];i;i=nxt[i])
    if(ver[i]!=fa&&ver[i]!=son[x])Add_ans(ver[i],x,x),Re_Add(ver[i],x);
}
signed main(){
  n=read();
  for(int i=1;i<=n;i++)a[i]=read(),maxn=max(maxn,a[i]);
  for(int i=1,x,y;i<n;i++)add(x=read(),y=read()),add(y,x);
  dfs(1,0);
  long long res=0;
  for(bit=0;(1<<bit)<=maxn;++bit){
    ans=0,dsu(1,0),Clear();
    res+=ans*(1ll<<bit);
  }cout<<res;
  return 0;
}

 

上一篇:C练题笔记之:Leetcode-717. 1比特与2比特字符


下一篇:Java锁