关于线段树

首先肯定推荐学姐博客

炒鸡优秀的学姐!

主要是贴代码,并没有什么理论的讲解。

例题:

洛谷p3372[模板]线段树1

洛谷p3373[模板]线段树2

线段树支持单点查询、单点修改、区间查询、区间修改等操作。

基本思想是二分

将线段树节点用一个结构体打包起来

建树:

build(1, 1, n);
void build (int k, int ll, int rr) {
    tr[k].l = ll;
    tr[k].r = rr;
    if (tr[k].l == tr[k].r) {
        scanf ("%d", &tr[k].v);
        return;
    }
    int mid = (tr[k].l + tr[k].r ) >> 1;
    build (lson, ll, mid);
    build (rson, mid + 1, rr);
    tr[k].v = tr[lson].v + tr[rson].v;
}

单点查询:

 

void ask_point(int k) {
    if(tr[k].l == tr[k].r) {ans = tr[k].w;return;}
    mid = (tr[k].l + tr[k].r) >> 1;
    if(x <= mid) ask_point(k << 1);
    else ask_point(k << 1 | 1);
} 

单点修改:

void change_point(int k) {
    if(tr[k].l == tr[k].r) {tr[k].w += k;return;}
    mid = (tr[k].l + tr[k].r) >> 1;
    if(x <= mid) change_point(k << 1);
    else change_point(k << 1 | 1);
    tr[k].w = tr[k << 1].w + tr[k << 1 | 1].w;
} 

区间查询:

void ask_query (int k) {
    if (tr[k].l >= z && tr[k].r <= y) {
        ans += tr[k].v;
        return;
    }
    if (tr[k].f) down (k);
    int m = (tr[k].l + tr[k].r) >> 1;
    if (z <= m) ask (lson);
    if (y > m) ask (rson);
    tr[k].v = tr[lson].v + tr[rson].v;
}

区间修改:

void change_query (int k) {
    if (tr[k].l >= z && tr[k].r <= y) {
        tr[k].v += (tr[k].r - tr[k].l + 1) * w;
        tr[k].f += w;
        return;
    }
    if (tr[k].f) down (k);
    int m = (tr[k].l + tr[k].r) >> 1;
    if (z <= m)  c_v (lson);
    if (y > m) c_v  (rson);
    tr[k].v = tr[lson].v + tr[rson].v;
}

lazy标记:

void down (int k) {
    tr[lson].f += tr[k].f;
    tr[rson].f += tr[k].f;
    tr[lson].v += (tr[lson].r - tr[lson].l + 1) * tr[k].f;
    tr[rson].v += (tr[rson].r - tr[rson].l + 1) * tr[k].f;
    tr[k].f = 0;
}

感性理解

例一代码:

#include <iostream>
#include <cstdio>
//#define int long long
//#define long long int
#define lson k << 1
#define rson k << 1 | 1
using namespace std;
int n, m, z, y, w;
long long ans;
struct node {
    int l, r, f;
    long long v;
}tr[400001];
void build (int k, int ll, int rr) {
    tr[k].l = ll;
    tr[k].r = rr;
    if (tr[k].l == tr[k].r) {
        scanf ("%d", &tr[k].v);
        return;
    }
    int mid = (tr[k].l + tr[k].r ) >> 1;
    build (lson, ll, mid);
    build (rson, mid + 1, rr);
    tr[k].v = tr[lson].v + tr[rson].v;
}
void down (int k) {
    tr[lson].f += tr[k].f;
    tr[rson].f += tr[k].f;
    tr[lson].v += (tr[lson].r - tr[lson].l + 1) * tr[k].f;
    tr[rson].v += (tr[rson].r - tr[rson].l + 1) * tr[k].f;
    tr[k].f = 0;
}
void c_v (int k) {
    if (tr[k].l >= z && tr[k].r <= y) {
        tr[k].v += (tr[k].r - tr[k].l + 1) * w;
        tr[k].f += w;
        return;
    }
    if (tr[k].f) down (k);
    int m = (tr[k].l + tr[k].r) >> 1;
    if (z <= m)  c_v (lson);
    if (y > m) c_v  (rson);
    tr[k].v = tr[lson].v + tr[rson].v;
}
void ask (int k) {
    if (tr[k].l >= z && tr[k].r <= y) {
        ans += tr[k].v;
        return;
    }
    if (tr[k].f) down (k);
    int m = (tr[k].l + tr[k].r) >> 1;
    if (z <= m) ask (lson);
    if (y > m) ask (rson);
    tr[k].v = tr[lson].v + tr[rson].v;
}
int main () {
    scanf ("%d%d", &n, &m);
    build (1, 1, n);
    for (int i = 1; i <= m; i++) {
        int x;
        scanf ("%d", &x);
        if (x == 1) {
            scanf ("%d%d%d", &z, &y, &w);
            c_v (1);
        }
        else if (x == 2) {
            scanf ("%d%d", &z, &y);
            ans = 0;
            ask (1);
            cout << ans << endl;
        }
    } 
    return 0;
}

例二代码:

#include <iostream>
#include <cstdio>
using namespace std;
struct node {
    long long l, r, w, add = 0, mul = 1;
}tr[400000];
long long a, b, ans, y, n, m, p;
int read () {
    long long s = 0;
    int w = 1;
    char ch = getchar ();
    while (!isdigit (ch)) {if (ch == '-') w = -1;ch = getchar ();}
    while (isdigit (ch)) {s = s * 10 + ch - '0';ch = getchar ();}
    return s * w;
}
void build (int k, int ll, int rr) {
    tr[k].l = ll;
    tr[k].r = rr;
    if (ll == rr) {
        tr[k].w = read ();
//        scanf ("%lld", &tr[k].w);
        tr[k].w %= p;
        return;
    }
    int m = (ll + rr) >> 1;
    build (k << 1, ll, m);
    build (k << 1 | 1, m + 1, rr);
    tr[k].w = (tr[k << 1].w + tr[k << 1 | 1].w) % p;
    return;
}
void down (int k) {
    tr[k << 1].mul = tr[k << 1].mul * tr[k].mul % p;
    tr[k << 1 | 1].mul = tr[k << 1 | 1].mul * tr[k].mul % p;
    tr[k << 1].add = (tr[k].mul * tr[k << 1].add + tr[k].add) % p;
    tr[k << 1 | 1].add = (tr[k].mul * tr[k << 1 | 1].add + tr[k].add) % p;
    tr[k << 1].w = (tr[k << 1].w * tr[k].mul % p + tr[k].add * (tr[k << 1].r - tr[k << 1].l + 1) % p) % p;
    tr[k << 1 | 1].w = (tr[k << 1 | 1].w * tr[k].mul % p + tr[k].add * (tr[k << 1 | 1].r - tr[k << 1 | 1].l + 1) % p) % p;
    tr[k].mul = 1;
    tr[k].add = 0;
    return;
}
void mul_interval (int k) {
    if (tr[k].l >= a && tr[k].r <= b) {
        tr[k].w = (tr[k].w * y) % p;
        tr[k].mul = (tr[k].mul * y) % p;
        tr[k].add = (tr[k].add * y) % p;
        return;
    }
    down (k);
    int m = (tr[k].l + tr[k].r) >> 1;
    if (a <= m) mul_interval (k << 1);
    if (b > m) mul_interval (k << 1 | 1);
    tr[k].w =  (tr[k << 1].w + tr[k << 1 | 1].w) % p;
    return;
}
void add_interval (int k) {
    if (tr[k].l >= a && tr[k].r <= b) {
        tr[k].add = tr[k].add + y;
        tr[k].w = (tr[k].w + (tr[k].r - tr[k].l + 1) * y) % p;
        return;
    }
    down (k);
    int m = (tr[k].r + tr[k].l) >> 1;
    if (a <= m) add_interval (k << 1);
    if (b > m) add_interval (k << 1 | 1);
    tr[k].w = (tr[k << 1].w + tr[k << 1 | 1].w) % p;
    return;
}
void ask_interval (int k) {
    if (tr[k].l >= a && tr[k].r <= b) {
        ans += tr[k].w;
        ans %= p;
        return;
    }
    down (k);
    int m = (tr[k].r + tr[k].l) >> 1;
    if (a <= m) ask_interval (k << 1);
    if (b > m) ask_interval (k << 1 | 1);
}
int main () {
    n = read ();
    m = read ();
    p = read ();
//    scanf ("%lld%lld%lld", &n, &m, &p);
    build (1, 1, n);
    for (int i = 1; i <= m; i++) {
        int nu;
        nu = read ();
//        scanf ("%lld", &nu);
        if (nu == 1) {
            a = read ();
            b = read ();
            y = read ();
//            scanf ("%lld%lld%lld", &a, &b, &y);
            mul_interval (1);
        }
        else if (nu == 2) {
            a = read ();
            b = read ();
            y = read ();
//            scanf ("%lld%lld%lld", &a, &b, &y);
            add_interval (1);
        }
        else if (nu == 3) {
            ans = 0;
            a = read ();
            b = read ();
//            scanf ("%lld%lld", &a, &b);
            ask_interval (1);
            cout << ans % p << endl;
        }
    }
    return 0;
}

谢谢收看, 祝身体健康!

上一篇:CF380C Sereja and Brackets 括号序列+线段树


下一篇:蒟蒻的数列[BZOJ4636](线段树)