UOJ#388. 【UNR #3】配对树 树链剖分+线段树

这道题卡常啊 !           

出题人说 $O(n \log^2 n)$ 可过,但我写了个 $O(n \log^2 n)$ 的树剖卡了半天常数.     

最暴力的做法:枚举区间,然后跑一个树形DP 来求最小匹配.     

显然,因为要求匹配值最小,所以一定是能匹配就先匹配.   

也就是说递归完 $x$ 的所有儿子后,$x$ 的每一个儿子最多只有 1 个点还没有匹配.      

这个时间复杂度是 $O(n^3)$ 的.    

然后我们对每一条边分别考虑:   

令 $v[x]$ 表示点 $x$ 到其父亲的边权(以 1 为根),那么 $v[x]$ 能产生贡献,当且仅当一个区间中 $x$ 子树中有奇数个点.   

这个很好理解,因为如果有奇数个点,就意味着 1 个点没有被匹配到,而需要向上延伸的 $x$ 的父亲,依此类推......       

那么就枚举右端点,然后令 $f[x][0/1]$ 分别表示多少个长度为偶数的区间满足在 $x$ 的子树中有偶数/奇数个点.      

由于要求区间长度是偶数,我们可以分别以 $1,2$ 为起点各跑一次,每次同时加入两个点来保证长度为偶数.      

考虑加入 $x,y$ 后的影响:

$x$ 到 $lca$ 与 $y$ 到 $lca$ (不包括 lca 这个点)的路径上 $f[x][0]=f[x][1]$,$f[x][1]=f[x][0]+1$             

不在 $x,y$ 路径上的点 $f[x][1]$ 不变,$f[x][0] \leftarrow f[x][0]+1$.     

这个暴力修改的话是 $O(n^2)$ 的,可以获得 $50$pts.  

满分算法的话就是用树链剖分+线段树来维护上面的东西.   

我们无外乎就是要支持:每个节点维护 $f[x][0],f[x][1]$,区间加,区间交换.  

然后定义标记 $(rev,x,y)$ 表示是否要交换 $f[x][0],f[x][1]$ 的值,交换后对 $f[x][0]$,$f[x][1]$ 分别加上 $x,y$.     

时间复杂度为 $O(n \log^2 n)$,但是会有点卡常.  

这里说几个卡常技巧: 

1. 读入优化  

2. 开 long long 要比取模快.  

3. 由于上述操作中每次加的数是 1 或 -1,所以这个标记可以直接开 int,然后区间和开 long long.    

code:    

#include <cstdio>
#include <ctime>
#include <cstring> 
#include <algorithm>      
#define N 100008  
#define ll long long 
#define mod 998244353
#define lson now<<1  
#define rson now<<1|1     
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;    
int edges,n,m,tim;    
int nd[N],f[N][2],fa[N];      
int hd[N],to[N<<1],nex[N<<1],val[N<<1]; 
int dep[N],a[N],size[N],top[N],son[N],dfn[N],bu[N];    
ll ans;   
struct data {
    int rev;  
    int vx,vy;  
    ll sx,sy,sum;    
    data(int rev=0,int vx=0,int vy=0):rev(rev),vx(vx),vy(vy){}  
}s[N<<2];   
inline void add(int u,int v,int c) {         
    nex[++edges]=hd[u];   
    hd[u]=edges,to[edges]=v,val[edges]=c;  
}            
void dfs(int x,int ff) {  
    size[x]=1;  
    fa[x]=ff,dep[x]=dep[ff]+1;        
    for(int i=hd[x];i;i=nex[i]) {              
        int y=to[i];  
        if(y==ff) continue;      
        nd[y]=val[i],dfs(y,x);    
        size[x]+=size[y];  
        if(size[y]>size[son[x]]) son[x]=y;    
    }
}
void dfs2(int x,int tp) {
    top[x]=tp;  
    dfn[x]=++tim;  
    bu[tim]=x;    
    if(son[x]) dfs2(son[x],tp); 
    for(int i=hd[x];i;i=nex[i]) 
        if(to[i]!=fa[x]&&to[i]!=son[x]) 
            dfs2(to[i],to[i]);   
}
inline int get_lca(int x,int y) { 
    while(top[x]!=top[y]) {
        dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];  
    }
    return dep[x]<dep[y]?x:y;   
}      
inline void pushup(int now) {
    s[now].sx=(ll)(s[lson].sx+s[rson].sx);  
    s[now].sy=(ll)(s[lson].sy+s[rson].sy);    
}
inline void mark_rev(int now) {
    swap(s[now].sx,s[now].sy);        
    swap(s[now].vx,s[now].vy);   
    s[now].rev^=1;     
}
inline void mark_add(int now,int vx,int vy) {
    if(vx) (s[now].sx+=(ll)vx*s[now].sum);   
    if(vy) (s[now].sy+=(ll)vy*s[now].sum);   
    if(vx) (s[now].vx+=vx);  
    if(vy) (s[now].vy+=vy);    
}
inline void pushdown(int now) {
    if(s[now].rev) {
        s[now].rev=0; 
        mark_rev(lson); 
        mark_rev(rson); 
    }   
    if(s[now].vx||s[now].vy) {
        mark_add(lson,s[now].vx,s[now].vy);  
        mark_add(rson,s[now].vx,s[now].vy);  
        s[now].vx=s[now].vy=0;   
    }
}
void build(int l,int r,int now) {
    s[now]=data(); 
    s[now].sx=0; 
    s[now].sy=0;      
    if(l==r) {
        s[now].sum=nd[bu[l]];   
        return; 
    }
    int mid=(l+r)>>1;  
    build(l,mid,lson),build(mid+1,r,rson);   
    s[now].sum=(ll)(s[lson].sum+s[rson].sum)%mod;   
}
void REV(int l,int r,int now,int L,int R) {
    if(l>=L&&r<=R) {
        mark_rev(now);   
        return;  
    }
    pushdown(now); 
    int mid=(l+r)>>1;   
    if(L<=mid) REV(l,mid,lson,L,R);  
    if(R>mid)  REV(mid+1,r,rson,L,R);   
    pushup(now);   
}
void ADD(int l,int r,int now,int L,int R,int vx,int vy) {
    if(l>=L&&r<=R) {
        mark_add(now,vx,vy);  
        return; 
    }
    pushdown(now); 
    int mid=(l+r)>>1;   
    if(L<=mid)  ADD(l,mid,lson,L,R,vx,vy);  
    if(R>mid)   ADD(mid+1,r,rson,L,R,vx,vy);  
    pushup(now);  
}     
inline void upd(int x,int y) {         
    while(top[y]!=top[x]) {  
        ADD(1,n,1,dfn[top[y]],dfn[y],-1,0);  
        REV(1,n,1,dfn[top[y]],dfn[y]);     
        ADD(1,n,1,dfn[top[y]],dfn[y],0,1);         
        y=fa[top[y]];   
    }     
    if(y!=x) {
        ADD(1,n,1,dfn[x]+1,dfn[y],-1,0);  
        REV(1,n,1,dfn[x]+1,dfn[y]);   
        ADD(1,n,1,dfn[x]+1,dfn[y],0,1);  
    } 
}
void sol(int st) {
    int x,y,lca;         
    build(1,n,1);   
    for(int i=st;i<=m;i+=2) {       
        if(i+1>m) break;  
        x=a[i],y=a[i+1];        
        if(dep[x]>dep[y]) swap(x,y);        
        lca=get_lca(x,y);      
        mark_add(1,1,0);              
        upd(lca,x); 
        upd(lca,y);   
        (ans+=s[1].sy)%=mod;
    }      
}         
char *p1,*p2,buf[100000];   
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)  
int rd()
{
    int x=0; char c;   
    while(c<48) c=nc();  
    while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc();  
    return x;    
}
int main() {  
    // setIO("input");  
    n=rd(),m=rd();  
    int x,y,z;  
    for(int i=1;i<n;++i) { 
        x=rd(),y=rd(),z=rd();  
        if(z>=mod) z-=mod;  
        add(x,y,z),add(y,x,z);  
    }    
    dfs(1,0);    
    dfs2(1,1);                   
    for(int i=1;i<=m;++i) a[i]=rd();          
    sol(1),sol(2);  
    printf("%lld\n",ans);   
    return 0;   
}

  

上一篇:高薪程序员&面试题精讲系列52之ConcurrentHashMap怎么统计大小?读操作需不需要加锁?


下一篇:ConcurrentHashMap基础原理