链接:https://www.nowcoder.com/acm/contest/59/F
时间限制:C/C++ 1秒,其他语言2秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
64bit IO Format: %lld
题目描述
给一个n个点的树,第i个点的值是vi,初始根是1。
有m个操作,每次操作:
1.将树根换为x。
2.给出两个点x,y,求所有点对(a,b)的个数满足a在x子树中,b在y子树中,va==vb
输入描述:
第一行两个数表示n,m
第二行n个数,表示每个点的点权vi
之后n-1行,每行两个数x,y表示一条边
之后m行,每行为:
1 x表示把根换成x点
2 x y表示查询x点的子树与y点的子树
输出描述:
对于每个询问,输出一行一个数表示答案
输入例子:
5 5
1 2 3 4 5
1 2
1 3
3 4
3 5
2 4 5
2 1 5
2 3 5
1 5
2 4 5
输出例子:
0
1
1
1
-->
示例1
输入
5 5
1 2 3 4 5
1 2
1 3
3 4
3 5
2 4 5
2 1 5
2 3 5
1 5
2 4 5
输出
0
1
1
1
备注:
对于100%的数据,1 <= n <= 1e5 , 1<= m <= 5e5 , 1 <= vi<= 1e9
/////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
不是很懂,,一直超时,可能被卡常数了,,但是不想再码了,很绝望,不辜负自己码并且调那么久,还是把自己的超时代码发上来
#include<cstdio>
#include<algorithm>
#define mst(a,b) memset((a),(b), sizeof a)
#define lowbit(a) ((a)&(-a))
#define IOS ios::sync_with_stdio(0);cin.tie(0);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int mod=1e9+;
const int maxn=1e5+;
const int maxm=8e6;
int n,m;
int val[maxn],mx;
pii use[maxn];
void init_val(){
for(int i=;i<=n;++i)use[i].first=val[i],use[i].second=i;
sort(use+,use++n);
int cc=;
val[use[].second]=;
for(int i=;i<=n;++i){
if(use[i].first==use[i-].first)val[use[i].second]=cc-;
else val[use[i].second]=cc++;
}
}
vector<int>to[maxn];
int dfsn[maxn],cnt; pii mine[maxn][];int v[maxn]; inline void dfs(int pos,int fa){
int le,ri;
dfsn[++cnt]=pos;
le=cnt;
for(int i=;i<to[pos].size();++i)if(to[pos][i]!=fa)dfs(to[pos][i],pos);
ri=cnt;
mine[pos][] = make_pair(le,ri);
} struct node{
int id,a,b;
node(int _id,int _a,int _b){id=_id,a=_a,b=_b;}
};
vector<node>w[maxn];
ll ans[maxn<<],cc;
struct query{
int foo,l,r,bel;bool ti;
query(int a=,bool b=,int c=,int d=):foo(a),ti(b),l(c),r(d){};
};
query all[maxm];int sz; inline void add_q(int l1,int r1,int l2,int r2,int foo){
if(r1&&r2) all[++sz]=query(foo,true,min(r1,r2),max(r1,r2));
if(l1-&&r2) all[++sz]=query(foo,false,min(l1-,r2),max(l1-,r2));
if(l2-&&r1) all[++sz]=query(foo,false,min(l2-,r1),max(l2-,r1));
if(l1-&&l2-) all[++sz]=query(foo,true,min(l1-,l2-),max(l1-,l2-));
} inline void add_query(int a,int b,int foo){
for(int i=;i<v[a];++i)for(int j=;j<v[b];++j)
add_q(mine[a][i].first,mine[a][i].second,mine[b][j].first,mine[b][j].second,foo);
}
inline void change(int from,int to){
v[from]=;
if(v[to]==){
mine[from][v[from]++] = make_pair(mine[to][].second+,mine[to][].first-);
}else{
if(mine[to][].first!=)
mine[from][v[from]++] = make_pair(,mine[to][].first-);
if(mine[to][].second!=n)
mine[from][v[from]++] = make_pair(mine[to][].second+,n); }
v[to]=;
mine[to][] = make_pair(,n);
}
inline void getq(int pos,int fa){
for(int i=;i<w[pos].size();++i)
add_query(w[pos][i].a , w[pos][i].b , w[pos][i].id); for(int i=;i<to[pos].size();++i){
int tt=to[pos][i];
if(tt==fa)continue;
pii g =mine[tt][];
change(pos,tt);
getq(tt,pos);
v[tt]=v[pos]=;
mine[tt][]=g,mine[pos][]=make_pair(,n);
}
}
int blo;
int x[maxm],y[maxm],c[maxm]; int lt[maxn],rt[maxn];
int main() {
#ifdef local
freopen("inpp","r",stdin);
// freopen("outpp","w",stdout);
#endif
scanf("%d%d",&n,&m);
for(int i=;i<=n;++i)scanf("%d",&val[i]),v[i]=;
init_val();
for(int i=;i<n;++i){
int a,b;scanf("%d%d",&a,&b);
to[a].push_back(b);to[b].push_back(a);
}
dfs(,);
int now=;
while(m--){
int od,a,b;
scanf("%d",&od);
if(od==)scanf("%d",&now);
else{
scanf("%d%d",&a,&b);++cc;
w[now].push_back(node(cc,a,b));
}
}
getq(,);
blo=sqrt(sz);if(blo==)blo=;
for(int i=;i<=sz;++i)all[i].bel=all[i].l/blo; for(int i=;i<=sz;++i)++c[all[i].r];//基数排序部分
for(int i=;i<=n;++i)c[i]+=c[i-];
for(int i=;i<=sz;++i)y[c[all[i].r]--]=i; int en=n/blo;
for(int i=;i<=en;++i)c[i]=;
for(int i=;i<=sz;++i)++c[all[i].bel];
for(int i=;i<=en;++i)c[i]+=c[i-];
for(int i=sz;i;--i)x[ c[ all[y[i]].bel ]-- ]=y[i]; ll uu=; int L=,R=; for(int i=;i<=sz;++i){
query&ha=all[x[i]];
while(R<ha.r){
++R;
int k=val[dfsn[R]];
uu-=(ll)rt[k]*lt[k];
++rt[k];
uu+=(ll)rt[k]*lt[k];
}
while(R>ha.r){
int k=val[dfsn[R]];
uu-=(ll)rt[k]*lt[k];
--rt[k];
uu+=(ll)rt[k]*lt[k];
--R;
}
while(L>ha.l){
int k=val[dfsn[L]];
uu-=(ll)rt[k]*lt[k];
--lt[k];
uu+=(ll)rt[k]*lt[k];
--L;
}
while(L<ha.l){
++L;
int k=val[dfsn[L]];
uu-=(ll)rt[k]*lt[k];
++lt[k];
uu+=(ll)rt[k]*lt[k];
}
if(ha.ti)ans[ha.foo]+=uu;
else ans[ha.foo]-=uu;
}
for(int i=;i<=cc;++i)
printf("%lld\n",ans[i]);
return ;
}