好题。。写了两个半小时hh,省选的时候要一个半小时内调出这种题目还真是难= =
题目大意是给一棵树或环套树,求点距大于等于K的点对数
这里的树状数组做了一点变换。不是向上更新和向下求和,而是反过来,所以求和的时候sum(k)实际上是求k到n的和
所以我们要求大于等于k的dis的次数和,就是求sum(1,k-1),注意k要减一
如果是树,就是常规的点分治,然后用树状数组维护dis【t】出现的次数
如果是环套树,找环之后割掉一条边,然后先求这棵树的答案。接着考虑过了这条割掉的边s--t的情况:我们以这条边的一点t为起点,对于环上的每个点(即每棵子树的根),我们求出这棵子树的所有dis后,dis+cir_len-i为所求链的第一部分,链的第二部分的长度为k-(dis+cir_len-i),用树状数组求就可以了。更新树状数组的时候不是更新dis,而是dis+i;i即根到割的那条边的另一个点s的距离&&这条割边
完美解决。。然而常数还是很大,跑了两秒多
#include<stdio.h> #include<string.h> #include<algorithm> #define INF 0x3f3f3f3f #define LL long long using namespace std; ; struct node{ int to,next; }e[maxn*]; int n,m,K,head[maxn],size[maxn],vis[maxn],sz,total,root,dis[maxn],tot,fa[maxn]; ; LL p[maxn*],ans; void insert(int u, int v){ e[++tot].to=v; e[tot].next=head[u]; head[u]=tot; } void add(int x, LL c){ for (;x;x-=x&-x) p[x]+=c; } LL query(int x){ //注意:这里的树状数组是倒过来的, query(1,k) 是求得k+1到n LL ret=; ) x=; *n;x+=x&-x) ret+=p[x]; return ret; } void getroot(int u, int f){ size[u]=; ; for (int v,i=head[u]; i; i=e[i].next){ if (vis[v=e[i].to] || v==f || i==ban1 || i==ban2) continue; getroot(v,u); size[u]+=size[v]; mx=max(mx,size[v]); } mx=max(mx,total-size[u]); if (mx<sz) sz=mx,root=u; } void getdis(int u, int f, int d){ dis[++tot]=d; for (int i=head[u],v; i; i=e[i].next){ if (vis[v=e[i].to] || v==f || i==ban1 || i==ban2) continue; getdis(v,u,d+); } } void work(int u){ total=size[u]?size[u]:n; sz=INF; getroot(u,); u=root; vis[u]=; tot=; ; i; i=e[i].next){ if (vis[v=e[i].to] || i==ban1 || i==ban2) continue; last=tot; getdis(v,,); //printf("%d\n", tot); ; j<=tot; j++) ans+=query(K--dis[j]); ; j<=tot; j++) add(dis[j],); } ans+=query(K-); ); for (int v,i=head[u]; i; i=e[i].next) if (!vis[v=e[i].to] && i!=ban1 && i!=ban2) work(v); } void find_cir(int u, int f){ vis[u]=; if (len) return;//printf(" %d\n", u); for (int i=head[u],v; i; i=e[i].next){ v=e[i].to; if (v==f || len) continue; fa[v]=u;// printf("now %d\n", u); if (vis[v]){ ban1=i; ban2=i^; for (int x=fa[v]; x!=v; x=fa[x]) cir[++len]=x; cir[++len]=v; return; } find_cir(v,u); } } void cut(){ ; i<=n; i++) vis[i]=; work();// printf(" %lld\n", ans); ; i<=n; i++) p[i]=0LL,vis[i]=; ; i<=len; i++) vis[cir[i]]=; ; i<=len; i++){ ; getdis(u,,); //printf(" %d\n", tot); ; j<=tot; j++) ans+=query(K-dis[j]-(len-i+));//, printf("%lld\n", ans); ); } } int main(){ scanf(; ,u,v; i<=m; i++){ scanf("%d%d", &u, &v); insert(u,v); insert(v,u); } ) work(); else{ find_cir(,); cut(); } printf("%lld\n", ans); ; }