首先肯定推荐学姐博客!
炒鸡优秀的学姐!
主要是贴代码,并没有什么理论的讲解。
例题:
线段树支持单点查询、单点修改、区间查询、区间修改等操作。
基本思想是二分
将线段树节点用一个结构体打包起来
建树:
build(1, 1, n); void build (int k, int ll, int rr) { tr[k].l = ll; tr[k].r = rr; if (tr[k].l == tr[k].r) { scanf ("%d", &tr[k].v); return; } int mid = (tr[k].l + tr[k].r ) >> 1; build (lson, ll, mid); build (rson, mid + 1, rr); tr[k].v = tr[lson].v + tr[rson].v; }
单点查询:
void ask_point(int k) { if(tr[k].l == tr[k].r) {ans = tr[k].w;return;} mid = (tr[k].l + tr[k].r) >> 1; if(x <= mid) ask_point(k << 1); else ask_point(k << 1 | 1); }
单点修改:
void change_point(int k) { if(tr[k].l == tr[k].r) {tr[k].w += k;return;} mid = (tr[k].l + tr[k].r) >> 1; if(x <= mid) change_point(k << 1); else change_point(k << 1 | 1); tr[k].w = tr[k << 1].w + tr[k << 1 | 1].w; }
区间查询:
void ask_query (int k) { if (tr[k].l >= z && tr[k].r <= y) { ans += tr[k].v; return; } if (tr[k].f) down (k); int m = (tr[k].l + tr[k].r) >> 1; if (z <= m) ask (lson); if (y > m) ask (rson); tr[k].v = tr[lson].v + tr[rson].v; }
区间修改:
void change_query (int k) { if (tr[k].l >= z && tr[k].r <= y) { tr[k].v += (tr[k].r - tr[k].l + 1) * w; tr[k].f += w; return; } if (tr[k].f) down (k); int m = (tr[k].l + tr[k].r) >> 1; if (z <= m) c_v (lson); if (y > m) c_v (rson); tr[k].v = tr[lson].v + tr[rson].v; }
lazy标记:
void down (int k) { tr[lson].f += tr[k].f; tr[rson].f += tr[k].f; tr[lson].v += (tr[lson].r - tr[lson].l + 1) * tr[k].f; tr[rson].v += (tr[rson].r - tr[rson].l + 1) * tr[k].f; tr[k].f = 0; }
感性理解
例一代码:
#include <iostream> #include <cstdio> //#define int long long //#define long long int #define lson k << 1 #define rson k << 1 | 1 using namespace std; int n, m, z, y, w; long long ans; struct node { int l, r, f; long long v; }tr[400001]; void build (int k, int ll, int rr) { tr[k].l = ll; tr[k].r = rr; if (tr[k].l == tr[k].r) { scanf ("%d", &tr[k].v); return; } int mid = (tr[k].l + tr[k].r ) >> 1; build (lson, ll, mid); build (rson, mid + 1, rr); tr[k].v = tr[lson].v + tr[rson].v; } void down (int k) { tr[lson].f += tr[k].f; tr[rson].f += tr[k].f; tr[lson].v += (tr[lson].r - tr[lson].l + 1) * tr[k].f; tr[rson].v += (tr[rson].r - tr[rson].l + 1) * tr[k].f; tr[k].f = 0; } void c_v (int k) { if (tr[k].l >= z && tr[k].r <= y) { tr[k].v += (tr[k].r - tr[k].l + 1) * w; tr[k].f += w; return; } if (tr[k].f) down (k); int m = (tr[k].l + tr[k].r) >> 1; if (z <= m) c_v (lson); if (y > m) c_v (rson); tr[k].v = tr[lson].v + tr[rson].v; } void ask (int k) { if (tr[k].l >= z && tr[k].r <= y) { ans += tr[k].v; return; } if (tr[k].f) down (k); int m = (tr[k].l + tr[k].r) >> 1; if (z <= m) ask (lson); if (y > m) ask (rson); tr[k].v = tr[lson].v + tr[rson].v; } int main () { scanf ("%d%d", &n, &m); build (1, 1, n); for (int i = 1; i <= m; i++) { int x; scanf ("%d", &x); if (x == 1) { scanf ("%d%d%d", &z, &y, &w); c_v (1); } else if (x == 2) { scanf ("%d%d", &z, &y); ans = 0; ask (1); cout << ans << endl; } } return 0; }
例二代码:
#include <iostream> #include <cstdio> using namespace std; struct node { long long l, r, w, add = 0, mul = 1; }tr[400000]; long long a, b, ans, y, n, m, p; int read () { long long s = 0; int w = 1; char ch = getchar (); while (!isdigit (ch)) {if (ch == '-') w = -1;ch = getchar ();} while (isdigit (ch)) {s = s * 10 + ch - '0';ch = getchar ();} return s * w; } void build (int k, int ll, int rr) { tr[k].l = ll; tr[k].r = rr; if (ll == rr) { tr[k].w = read (); // scanf ("%lld", &tr[k].w); tr[k].w %= p; return; } int m = (ll + rr) >> 1; build (k << 1, ll, m); build (k << 1 | 1, m + 1, rr); tr[k].w = (tr[k << 1].w + tr[k << 1 | 1].w) % p; return; } void down (int k) { tr[k << 1].mul = tr[k << 1].mul * tr[k].mul % p; tr[k << 1 | 1].mul = tr[k << 1 | 1].mul * tr[k].mul % p; tr[k << 1].add = (tr[k].mul * tr[k << 1].add + tr[k].add) % p; tr[k << 1 | 1].add = (tr[k].mul * tr[k << 1 | 1].add + tr[k].add) % p; tr[k << 1].w = (tr[k << 1].w * tr[k].mul % p + tr[k].add * (tr[k << 1].r - tr[k << 1].l + 1) % p) % p; tr[k << 1 | 1].w = (tr[k << 1 | 1].w * tr[k].mul % p + tr[k].add * (tr[k << 1 | 1].r - tr[k << 1 | 1].l + 1) % p) % p; tr[k].mul = 1; tr[k].add = 0; return; } void mul_interval (int k) { if (tr[k].l >= a && tr[k].r <= b) { tr[k].w = (tr[k].w * y) % p; tr[k].mul = (tr[k].mul * y) % p; tr[k].add = (tr[k].add * y) % p; return; } down (k); int m = (tr[k].l + tr[k].r) >> 1; if (a <= m) mul_interval (k << 1); if (b > m) mul_interval (k << 1 | 1); tr[k].w = (tr[k << 1].w + tr[k << 1 | 1].w) % p; return; } void add_interval (int k) { if (tr[k].l >= a && tr[k].r <= b) { tr[k].add = tr[k].add + y; tr[k].w = (tr[k].w + (tr[k].r - tr[k].l + 1) * y) % p; return; } down (k); int m = (tr[k].r + tr[k].l) >> 1; if (a <= m) add_interval (k << 1); if (b > m) add_interval (k << 1 | 1); tr[k].w = (tr[k << 1].w + tr[k << 1 | 1].w) % p; return; } void ask_interval (int k) { if (tr[k].l >= a && tr[k].r <= b) { ans += tr[k].w; ans %= p; return; } down (k); int m = (tr[k].r + tr[k].l) >> 1; if (a <= m) ask_interval (k << 1); if (b > m) ask_interval (k << 1 | 1); } int main () { n = read (); m = read (); p = read (); // scanf ("%lld%lld%lld", &n, &m, &p); build (1, 1, n); for (int i = 1; i <= m; i++) { int nu; nu = read (); // scanf ("%lld", &nu); if (nu == 1) { a = read (); b = read (); y = read (); // scanf ("%lld%lld%lld", &a, &b, &y); mul_interval (1); } else if (nu == 2) { a = read (); b = read (); y = read (); // scanf ("%lld%lld%lld", &a, &b, &y); add_interval (1); } else if (nu == 3) { ans = 0; a = read (); b = read (); // scanf ("%lld%lld", &a, &b); ask_interval (1); cout << ans % p << endl; } } return 0; }
谢谢收看, 祝身体健康!