Description
Solution
标签里就俩东西,没错就是他们了。
观察到 4,5,6 操作中需要用到一个常数 \(v\),所以我们的矩阵得开到 1 * 4
,存 \((A, B, C, 1)\)。
至于转移的 6 个矩阵这里就不推了,比较基础(不会真的有人来做这道题了连矩阵都不会推吧)。
简单说下线段树上矩阵乘法的含义。
线段树上的每个节点都维护一个矩阵。
叶子节点就是每个矩阵本身。
其他的点就是所表示区间内的所有矩阵的对应位的和。
-
对于更新操作:
跟普通的线段树一样,就是把一个数乘上一个数改成了一个矩阵乘上一个矩阵,\(lazy\) 标记什么的也都一样更新。
注意:矩阵乘法没有交换律,可别瞎乘了。
-
对于询问:
也跟普通的线段树是一样的。真的一点区别都没有。
但是本蒟蒻在一开始写的时候觉得有一点点无从下手,这里记录一下我的思考历程。
Q:这 6 个操作怎么用矩阵乘法维护啊?
A:这个简单,随便推推矩阵就好了,矩阵中是应该是 v
的地方先赋成 0,需要用的时候再给他赋上 v
。
Q:那这 6 个矩阵我怎么存下来呢,执行操作的时候我该怎么调用数组呢?
A:不要怕麻烦!直接开 6 个 4 * 4
的数组就完了!调用的时候判断该用哪个用哪个。
Q:我线段树 build
的时候初值怎么赋呢?
A:初值就是输入的东西,对于每一个叶子节点把输入的数直接扔进去就好了。
Q:我需要用到线段树的哪些操作呢?
A:在线段树上矩阵乘法一般都只需要区间乘法,这题也不例外,所以我们只需要维护一个乘法的懒惰标记。在 \(update\) 函数中,直接传一个矩阵变量即可。
再讲讲卡常:
- 最为重要: 矩阵乘法的最内层循环使用循环展开,加法取模改成减法。
- 使用快读。
- 不要开 \(long \ long\),乘法的时候强制类型转换(不会真的有人全开 \(long \ long\) 了吧)。
- 如果你不怕麻烦的话,不要使用重载的乘法运算符,因为那样是
4 * 4
的两个矩阵相乘,而你只需要1 * 4
\(\times\)4 * 4
。
其实只要有第一条和第三条基本上就能卡过去了。
下面就是你们最想看到的代码了 \(:)\)
Code
#include <bits/stdc++.h>
using namespace std;
namespace IO{
inline int read(){
int x = 0;
char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
template <typename T> inline void write(T x){
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;
const int N = 3e5 + 10;
const int mod = 998244353;
int n, m;
inline int add(int x) {return x > mod ? x - mod : x;}
struct matrix{
int num[4][4];
matrix() {memset(num, 0, sizeof(num));}
inline void clear() {memset(num, 0, sizeof(num));}
inline void init() {num[0][0] = num[1][1] = num[2][2] = num[3][3] = 1;}
matrix(const int a[][4]){
for(int i = 0; i < 4; ++i)
for(int j = 0; j < 4; ++j)
num[i][j] = a[i][j];
}
matrix operator + (const matrix &b) const{
matrix r;
for(int i = 0; i < 4; ++i)
for(int j = 0; j < 4; ++j)
r.num[i][j] = (1ll * num[i][j] + b.num[i][j]) % mod;
return r;
}
matrix operator * (const matrix &b) const{
matrix r;
for(int i = 0; i < 4; ++i)
for(int j = 0; j < 4; ++j){
r.num[i][j] = add(r.num[i][j] + (1ll * num[i][0] * b.num[0][j]) % mod);
r.num[i][j] = add(r.num[i][j] + (1ll * num[i][1] * b.num[1][j]) % mod);
r.num[i][j] = add(r.num[i][j] + (1ll * num[i][2] * b.num[2][j]) % mod);
r.num[i][j] = add(r.num[i][j] + (1ll * num[i][3] * b.num[3][j]) % mod);
}
return r;
}
}a[N];
const int _A[4][4] = {
1, 0, 0, 0,
1, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1
};
matrix A(_A);
const int _B[4][4] = {
1, 0, 0, 0,
0, 1, 0, 0,
0, 1, 1, 0,
0, 0, 0, 1
};
matrix B(_B);
const int _C[4][4] = {
1, 0, 1, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1
};
matrix C(_C);
const int _D[4][4] = {
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1 // [3][0] = v
};
matrix D(_D);
const int _E[4][4] = {
1, 0, 0, 0,
0, 0, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1 // [1][1] = v
};
matrix E(_E);
const int _F[4][4] = {
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 0, 0,
0, 0, 0, 1 // [3][2] = v
};
matrix F(_F);
#define ls rt << 1
#define rs rt << 1 | 1
matrix sum[N << 2], lazy[N << 2];
inline void pushup(int rt){
sum[rt] = sum[ls] + sum[rs];
}
inline void pushdown(int rt){
sum[ls] = sum[ls] * lazy[rt];
sum[rs] = sum[rs] * lazy[rt];
lazy[ls] = lazy[ls] * lazy[rt];
lazy[rs] = lazy[rs] * lazy[rt];
lazy[rt].clear(), lazy[rt].init();
}
inline void build(int l, int r, int rt){
lazy[rt].init();
if(l == r){
sum[rt] = a[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
pushup(rt);
}
inline void update(int L, int R, matrix k, int l, int r, int rt){
if(L <= l && r <= R){
sum[rt] = sum[rt] * k;
lazy[rt] = lazy[rt] * k;
return;
}
pushdown(rt);
int mid = (l + r) >> 1;
if(L <= mid) update(L, R, k, l, mid, ls);
if(R > mid) update(L, R, k, mid + 1, r, rs);
pushup(rt);
}
inline matrix query(int L, int R, int l, int r, int rt){
if(L <= l && r <= R) return sum[rt];
pushdown(rt);
int mid = (l + r) >> 1;
matrix res;
if(L <= mid) res = res + query(L, R, l, mid, ls);
if(R > mid) res = res + query(L, R, mid + 1, r, rs);
return res;
}
int main(){
n = read();
for(int i = 1; i <= n; ++i)
a[i].num[0][0] = read(), a[i].num[0][1] = read(), a[i].num[0][2] = read(), a[i].num[0][3] = 1;
build(1, n, 1);
m = read();
while(m--){
int op = read(), l = read(), r = read();
if(op == 1) update(l, r, A, 1, n, 1);
else if(op == 2) update(l, r, B, 1, n, 1);
else if(op == 3) update(l, r, C, 1, n, 1);
else if(op == 4) D.num[3][0] = read(), update(l, r, D, 1, n, 1);
else if(op == 5) E.num[1][1] = read(), update(l, r, E, 1, n, 1);
else if(op == 6) F.num[3][2] = read(), update(l, r, F, 1, n, 1);
else{
matrix ans = query(l, r, 1, n, 1);
printf("%d %d %d\n", ans.num[0][0], ans.num[0][1], ans.num[0][2]);
}
}
return 0;
}
\[\_EOF\_
\]