线段树进阶 —— 权值线段树与动态开点

前置知识:线段树

权值线段树

我们都知道,普通的线段树是一种用来维护序列区间最值的一种数据结构。

而权值线段树,就是将序列中各个值出现的频数作为权值,再用线段树来维护值域的数据结构。

与其说是一种数据结构,更不如说是一个线段树的 trick。

实际代码的话,用线段树维护桶数组就行了。

权值线段树重要作用是反应序列中元素的大小问题,
如求第k大第k小问题。

因为本身就与普通线段树差不多,所以就直接放模板了((

code:

#include <bits/stdc++.h>
using namespace std;

const int N=3e5+10;

int n,k_;
int bucket_[N];

int tree[N<<1];

void push_up(int node)
{
    tree[node]=tree[node<<1]+tree[node<<1|1];
}

void build(int node,int start,int end)
{
    if(start==end)
    {
        tree[node]=bucket_[start];
        return ;
    }
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    build(lnode,start,mid);
    build(rnode,mid+1,end);
    push_up(node);
}

void update(int node,int start,int end,int k,int val)//大小为val的数多k个,相当于单点修改
{
    if(start==end)
    {
        tree[node]+=k;
        return ;
    }
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(val<=mid) update(lnode,start,mid,k,val);
    else update(rnode,mid+1,end,k,val);

    push_up(node);
}

int query(int node,int start,int end,int val)//查询数字val有多少个,相当于单点查询
{
    if(start==end) return tree[node];
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(val<=mid) return query(lnode,start,mid,val);
    else return query(rnode,mid+1,end,val);
    push_up(node);
}

int query_kth(int node,int start,int end,int k)//查询第k小
{
    if(start==end) return start;
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(tree[lnode]>=k) return query_kth(lnode,start,mid,k);//如果左子树的权值大于k,证明第k小值左子树
    else return query_kth(rnode,mid+1,end,k-tree[lnode]);//进入右子树时,整个区间的第k小相当于右区间的第(k-左区间)小,记得减去左子树的值
}

int query_kthbig(int node,int start,int end,int k)//查询第k大
{
    if(start==end) return start;
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(tree[rnode]>=k) return query_kthbig(rnode,mid+1,end,k);//若右子树的权值大于k,证明第k大值在右子树
    else return query_kthbig(lnode,start,mid,k-tree[rnode]);//进入左子树时,记得减去右区间
}

int main()
{
    scanf("%d%d",&n,&k_);
    int maxn=0,tot=0;
    for(int i=1; i<=n; i++)
    {
        int x;
        scanf("%d",&x);
        maxn=maxn>=x?maxn:x;
        bucket_[x]++;
    }
    build(1,1,maxn);
    cout<<query_kth(1,1,maxn,k_);
    return 0;
}

板子题:

求k小整数

由于重复的不算,用桶统计出现次数的时候只统计到1就行了。

例题:

逆序对

值域\(10^9\),显然不是数组能直接开下的。

一看 \(n\) 的大小,才 \(5 \times 10^5\),可以考虑离散化。

离散化是很常用的缩小值域的方法,好写好用。

缺点也很显著:只能离线。

完成离散化后,我们就可以心安理得地建立一棵空的权值线段树了。

按顺序遍历离散化后数组,每遍历到一个值 \(x\) ,先在线段树中查询有多少大于 \(x\) 的数,加入答案数中,然后将当前值插入线段树中。

code:码风略毒

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=5e5+10;

struct Num
{
    int no__,val__;
} arr[N];//用结构体的方式存原数组,方便离散化

bool cmp(Num x,Num y)
{
    if(x.val__==y.val__) return x.no__<y.no__;
    return x.val__<y.val__;
}

int n;
int poi[N];//离散化之后的数组
ll tree[N<<2];

void push_up(int node)
{
    tree[node]=tree[node<<1]+tree[node<<1|1];
}
void build(int node,int start,int end)
{
    if(start==end)
    {
        tree[node]=0;//先建空树
        return ;
    }
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    build(lnode,start,mid);
    build(rnode,mid+1,end);

    push_up(node);

}

void update(int node,int start,int end,int k,int val)//单点修改
{
    if(start==end)
    {
        tree[node]=(ll)tree[node]+k;
        return;
    }
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(val<=mid) update(lnode,start,mid,k,val);
    else update(rnode,mid+1,end,k,val);

    push_up(node);
}

ll query(int node,int start,int end,int l,int r)//查询值为[l,r]之间数的出现次数和
{
    if(end<l||start>r) return 0;
    if(l<=start&&end<=r) return tree[node];

    ll sum=0;
    int mid=start+end>>1;
    int lnode=node<<1;
    int rnode=node<<1|1;
    if(l<=mid) sum+=query(lnode,start,mid,l,r);
    if(r>mid) sum+=query(rnode,mid+1,end,l,r);

    push_up(node);
    return sum;
}

int main()
{
    scanf("%d",&n);
    for(int i=1; i<=n; i++)
    {
        scanf("%d",&arr[i].val__);
        arr[i].no__=i;
    }

    sort(arr+1,arr+1+n,cmp);
    for(int i=1; i<=n; i++)
        poi[arr[i].no__]=i;//离散化

    ll ans=0;
    for(int i=1; i<=n; i++)
    {
        int x=poi[i];
        ans+=query(1,1,n,x+1,n);
        update(1,1,n,1,x);
    }
    printf("%lld",ans);
    return 0;
}

动态开点

刚刚我们说了,线段树数组维护区间远大于实际使用时,可以试着利用离散化的方式缩小值域。

但是离散化只能用于离线算法,如果题目不让你离线怎么办呢

于是就有了动态开点。

动态开点,正如其名,用多少点我们就开多少点。

然而在动态开点线段树中,存储必须以结构体形式。因为动态开点的点的下标是我们人为规定。同时在传节点编号时也要传引用。

动态开点的修改函数与查询函数也有所改动:

#define lnode tree[node].lson
#define rnode tree[node].rson

void update(int &node,int start,int end,int k)//node是人为规定的编号,所以传引用·
{
    if(!node)//新建节点
    {
        node=++tot;
        lnode=rnode=tree[node].sum=0;
    }
    if(start==end)
    {
        tree[node].sum++;
        return ;
    }
    int mid=start+end>>1;
    if(k<=mid) update(lnode,start,mid,k);
    else update(rnode,mid+1,end,k);

    push_up(node);
}

ll query(int node,int start,int end,int l,int r)
{
    if(!node) return 0;//这个节点未被创建, 返回0
    if(l<=start&&end<=r) return tree[node].sum;

    ll sum=0;
    int mid=start+end>>1;
    if(l<=mid) sum+=query(lnode,start,mid,l,r);
    if(r>mid) sum+=query(rnode,mid+1,end,l,r);

    return sum;
}

但是整体的框架仍然是普通线段树。

求逆序对(动态开点权值线段树)code:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=5e5+10;
const int MAX=1e9;

int n;
struct node
{
    ll sum;
    int lson,rson;
} tree[N<<5];
int tot=0;//计数变量, 记录开了多少节点

#define lnode tree[node].lson
#define rnode tree[node].rson

void push_up(int node)
{
    tree[node].sum=tree[lnode].sum+tree[rnode].sum;
}

void update(int &node,int start,int end,int k)//node是人为规定的编号,所以传引用·
{
    if(!node)//新建节点
    {
        node=++tot;
        lnode=rnode=tree[node].sum=0;
    }
    if(start==end)
    {
        tree[node].sum++;
        return ;
    }
    int mid=start+end>>1;
    if(k<=mid) update(lnode,start,mid,k);
    else update(rnode,mid+1,end,k);

    push_up(node);
}

ll query(int node,int start,int end,int l,int r)
{
    if(!node) return 0;//这个节点未被创建, 返回0
    if(l<=start&&end<=r) return tree[node].sum;

    ll sum=0;
    int mid=start+end>>1;
    if(l<=mid) sum+=query(lnode,start,mid,l,r);
    if(r>mid) sum+=query(rnode,mid+1,end,l,r);

    return sum;
}

int main()
{
    scanf("%d",&n);
    ll ans=0;
    int root=1;
    tot=1;
    for(int i=1; i<=n; i++)
    {
        int x;
        scanf("%d",&x);
        if(x+1<=MAX) ans+=query(1,1,MAX,x+1,MAX);
        update(root,1,MAX,x);
    }
    printf("%lld\n",ans);
    return 0;
}

\(71\) 行,还挺短

上一篇:链表必学算法(三):归并法


下一篇:单链表的创建