大致题意:
给你一段含n个数字的序列,对于这段序列可以有m次操作。
操作有两种类型:
1、(1,L,R)表示将(L,R)区间的每个数加自身的lowbit值(若一个数为x,则其lowbit值为x&-x).
2、(2,L,R)询问区间(L,R)数字之和
思路:
一眼看上去知道要用线段树维护区间信息,但对于操作1对区间每个数都需要加上其lowbit值,似乎直接用单点修改来做会T得很惨,但是仔细一想加的值很特殊是该数的lowbit的值,我们仔细一想便会发现一个数最多加logn次其lowbit值后继续加上lowbit值就变成了乘2(因为此时该数二进制形式上只有最高位为1了)。那么我们对于每个数的单点修改操作最多也只需要进行nlogn次然后题目就变成了一个普通线段树的区间修改和区间求和问题。
所以我们只需要在建造线段树时每个节点加上一个判断是乘2还是加上lowbit的值即可。但因为题目中存在取模操作,对于加上lowbit的情况贸然进行取模可能会导致该数二进制位情况的变化,从而使得结果错误,所以我们需要在当该数还未到达直接乘2的情况时不对其进行取模,直到该数符合乘2的情况我们对该数修改时才进行取模操作
```cpp
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std;
#define int long long
const int N = 1e5 + 5, mod = 998244353;
struct node {
int l, r, sum, flag;
};
int cnt[N], lazy[N * 4];
node tr[N * 4];
int lowbit(int u) { return u & -u; }
bool check(int u) {
if ((u + lowbit(u)) == (u + u)) return true;
else return false;
}
void push_up(node& u, node& l, node& r) {
if (l.flag == 1 && r.flag == 1) u.flag = 1;
else u.flag = 0;
u.sum = (l.sum + r.sum) % mod;
return;
}
void push_up(int u) {
push_up(tr[u], tr[u << 1], tr[u << 1 | 1]);
return;
}
void build(int u, int l, int r) {
if (l == r) {
tr[u] = { l,l,cnt[l] };
if (check(cnt[l])) tr[u].flag = 1;
else tr[u].flag = 0;
return;
}
tr[u] = { l,r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
push_up(u);
return;
}
void push_down(int u) {
if (lazy[u] > 1) {
lazy[u << 1] = (lazy[u << 1] * lazy[u]) % mod;
lazy[u << 1 | 1] = (lazy[u << 1 | 1] * lazy[u]) % mod;
tr[u << 1].sum = (tr[u << 1].sum * lazy[u]) % mod;
tr[u << 1 | 1].sum = (tr[u << 1 | 1].sum * lazy[u]) % mod;
lazy[u] = 1;
return;
}
}
void modify(int u, int l, int r) {
if (tr[u].l == tr[u].r) {
if (tr[u].flag == 1) tr[u].sum = (tr[u].sum * 2) % mod;
else {
tr[u].sum = (tr[u].sum + lowbit(tr[u].sum));
if (check(tr[u].sum)) tr[u].flag = 1;
}
return;
}
if (tr[u].l >= l && tr[u].r <= r && tr[u].flag == 1) {
lazy[u] = (lazy[u] * 2) % mod;
tr[u].sum = (tr[u].sum * 2) % mod;
return;
}
push_down(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r);
if (r > mid) modify(u << 1 | 1, l, r);
push_up(u);
return;
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum % mod;
push_down(u);
int ans = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) ans = (ans + query(u << 1, l, r)) % mod;
if (r > mid) ans = (ans + query(u << 1 | 1, l, r)) % mod;
return ans % mod;
}
signed main() {
int t; scanf("%lld", &t);
while (t--) {
int n; scanf("%lld", &n);
for (int i = 1; i <= n; i++) scanf("%lld", &cnt[i]);
for (int i = 1; i <= 4 * n; i++) lazy[i] = 1, tr[i].sum = 0, tr[i].flag = 0;
build(1, 1, n);
int m; scanf("%lld", &m);
while (m--) {
int op, l, r; scanf("%lld %lld %lld", &op, &l, &r);
if (op == 1)
modify(1, l, r);
else if (op == 2) {
int ans = query(1, l, r) % mod;
printf("%lld\n", ans);
}
}
}
return 0;
}
在这里插入代码片