原题链接
考察:主席树+树状数组
实际是动态主席树的模板题,反正本蒟蒻不会(.
思路:
主席树实际是有n个根结点的线段树,如果我们修改第i棵主席树的值,后面i~n棵树都需要修改,时间复杂度最坏是\(O(n*m)\)级别的,但是主席树求区间第k小,实际就是求前缀和,而操作又涉及单点修改,这里可以考虑树状数组,n个序列,维护n棵主席树,每棵树都代表树状数组的结点.也就是说,\(root[i]\)维护的不是\([1,i]\),而是\([i-lowbit(i)+1,i]\),每一棵主席树所依赖的上一个版本是它自己
动态开点只需要2倍空间,最坏是\(2(n+m)\),每次树状数组操作\(log_2n\)棵树,每次新开\(log_2n\)个结点,空间开到\(O((n+m)*2+(n+m)log_2^2n)\)
时间复杂度是\(O((n+m)log_2^2n)\)
Code
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 100010,S = 20;
int n, m,a[N],idx,root[N],X[S],Y[S],cntx,cnty;
char op[2];
vector<int> nums;
struct Node{
int l, r, cnt;
Node operator=(const Node& x){
this->l = x.l;
this->r = x.r;
this->cnt = x.cnt;
return *this;
}
} tr[N*582];
struct Query{
int l, r, k,i,t;
bool Q;
} query[N];
int find(int x)
{
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
int build(int l,int r)
{
int p = ++idx;
if(l==r)
return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid);
tr[p].r = build(mid + 1, r);
tr[p].cnt = 0;
return p;
}
int insert(int last,int l,int r,int val,int x)
{
int p = ++idx;
tr[p] = tr[last];
if(l==r)
{
tr[p].cnt+=x;
return p;
}
int mid = l + r >> 1;
if(val<=mid)
tr[p].l = insert(tr[last].l, l, mid, val,x);
else
tr[p].r = insert(tr[last].r, mid + 1, r, val,x);
tr[p].cnt = tr[tr[p].l].cnt+tr[tr[p].r].cnt;
return p;
}
int lowbit(int x)
{
return x & -x;
}
void add(int k,int v)
{
int x = find(a[k]);
for (int i = k; i <= n;i+=lowbit(i))
root[i] = insert(root[i], 0, nums.size()-1, x, v);
}
int ask(int l,int r,int k)
{
if(l==r)
return nums[l];
int sum = 0;
for (int i = 1; i <= cntx;i++)
sum -= tr[tr[X[i]].l].cnt;
for (int i = 1; i <= cnty;i++)
sum += tr[tr[Y[i]].l].cnt;
int mid = l + r >> 1;
if(k<=sum)
{
for (int i = 1; i <= cntx;i++)
X[i] = tr[X[i]].l;
for (int i = 1; i <= cnty;i++)
Y[i] = tr[Y[i]].l;
return ask(l, mid, k);
}
else{
for (int i = 1; i <= cntx;i++)
X[i] = tr[X[i]].r;
for (int i = 1; i <= cnty;i++)
Y[i] = tr[Y[i]].r;
return ask(mid + 1, r, k - sum);
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n;i++)
scanf("%d", &a[i]), nums.push_back(a[i]);
for (int i = 1; i <= m;i++)
{
scanf("%s", op);
if(op[0]=='Q')
{
scanf("%d%d%d", &query[i].l, &query[i].r, &query[i].k);
query[i].Q = 1;
continue;
}
scanf("%d%d", &query[i].i, &query[i].t);
nums.push_back(query[i].t);
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
root[0] = build(0, nums.size() - 1);
for(int i=1;i<=n;i++) root[i] = 1;
for (int i = 1; i <= n;i++) add(i, 1);
for (int i = 1; i <= m;i++)
{
if(query[i].Q)
{
cntx = 0,cnty = 0;
for (int j = query[i].l-1; j;j-=lowbit(j))
X[++cntx] = root[j];
for (int j = query[i].r; j;j-=lowbit(j))
Y[++cnty] = root[j];
printf("%d\n", ask(0,nums.size()-1, query[i].k));
continue;
}
add(query[i].i, -1);
a[query[i].i] = query[i].t;
add(query[i].i, 1);
}
return 0;
}