思路就是对于每一个删除的数,统计在它之前有多少比它大,在它之后有多少比它小。
然后直接想到树套树(实际上是想不到更好的方法),然而树套树极其难写,所以发一个归并树的题解吧。
归并树是对于线段树每一个节点,维护一个从L到R的有序数组(一般用vector实现防止MLE),有点像归并排序的过程。
对于删除操作,我们沿着线段树的查询路径,一路更改遇到的vector,删除对应的数(用二分查找实现是logn),所以一次修改的消耗是log^2(n)。
查询就是线段树的查询方式,把询问区间拆成若干互不相交的子区间,在每个区间内二分查找,然后累加。
PS:因为归并树玄学的常数问题,需要开O2
#include<cstdio> #include<cstdlib> #include<vector> #define maxn 100010 #include<algorithm> #define ll long long #define IT vector<int>::iterator using namespace std; struct node{ int l,r; vector<int> v; } tree[maxn<<2]; int n,m,a[maxn],c[maxn],p[maxn],l1[maxn],l2[maxn]; int lowbit(int x){ return x&(-x); } void change(int x){ while(x<=n){ c[x]++; x+=lowbit(x); } } int query(int x){ int ans=0; while(x){ ans+=c[x]; x-=lowbit(x); } return ans; } void build(int pos,int left,int right){ int mid=left+right>>1; int cnt1=0,cnt2=0; //标准递归线段树 tree[pos].l=left;tree[pos].r=right; if(left==right){ tree[pos].v.push_back(a[left]); return; } build(pos<<1,left,mid); build(pos<<1|1,mid+1,right); //把左右子节点的有序数组 归并起来 for(register IT i=tree[pos<<1].v.begin();i!=tree[pos<<1].v.end();i++) l1[++cnt1]=*i; for(register IT i=tree[pos<<1|1].v.begin();i!=tree[pos<<1|1].v.end();i++) l2[++cnt2]=*i; int i=1,j=1; while(i<=cnt1&&j<=cnt2){ if(l1[i]<l2[j]) tree[pos].v.push_back(l1[i++]); else tree[pos].v.push_back(l2[j++]); } while(i<=cnt1) tree[pos].v.push_back(l1[i++]); while(j<=cnt2) tree[pos].v.push_back(l2[j++]); return; } void modify(int pos,int tar){ int mid=tree[pos].l+tree[pos].r>>1; IT it=lower_bound(tree[pos].v.begin(),tree[pos].v.end(),a[tar]);//找到目标数位置 tree[pos].v.erase(it);//删除 if(tree[pos].l==tree[pos].r) return; if(tar<=mid) modify(pos<<1,tar); else modify(pos<<1|1,tar); return; } int ask(int pos,int left,int right,int key,int type){//type表示要统计比key小的还是比key大的 IT it; int ans=0; //如果当前区间包含于询问区间,二分到第一个大于key的位置 if(tree[pos].l>=left&&tree[pos].r<=right){ it=upper_bound(tree[pos].v.begin(),tree[pos].v.end(),key);//注意是upper_bound if(type==1) return it-tree[pos].v.begin();//在后面统计比它小的 else return tree[pos].v.end()-it;//在前面统计比它大的 } //递归查左右子树 int mid=tree[pos].l+tree[pos].r>>1; if(left<=mid) ans+=ask(pos<<1,left,right,key,type); if(right>mid) ans+=ask(pos<<1|1,left,right,key,type); return ans; } int main(){ int x; ll tot=0; //标准读入 scanf("%d%d",&n,&m); for(register int i=1;i<=n;++i){ scanf("%d",&a[i]); p[a[i]]=i; } //树状数组处理初始逆序对 for(register int i=n;i>=1;i--){ change(a[i]); tot+=query(a[i]-1); } build(1,1,n); printf("%lld\n",tot); for(register int i=1;i<m;i++){ scanf("%d",&x); modify(1,p[x]);//删除 tot-=ask(1,p[x]+1,n,x,1);//减去前后两部分的贡献 tot-=ask(1,1,p[x]-1,x,2); printf("%lld\n",tot); } scanf("%d",&x); return 0; }