有一个 \(n \times n ( n \leq 10^6)\)的正方形网格,用红色,绿色,蓝色三种颜色染色,求有多少种染色方案使得至少一行或一列是同一种颜色。结果对 \(998244353\) 取模
有一个很显然的\(O(n^2)\)的容斥做法:枚举至少有多少行和多少列被染了色,那么显然答案为
\(ans=\sum_{i=0...n,j=0...n,i+j>0} C_n^iC_n^j(-1)^{i+j+1}3^{(n-i)(n-j)+1}\)
对原始进行化简 , 考虑只枚举一维 \(i\) , 剩下一维 \(j\) 转化为一个\(O(1)\)的式子.
接下来是实现细节.
不光是要发现\(i+j\not=0\) 这个条件非常讨嫌 , 而且\(i=0\)或\(j=0\)时各行或各列的颜色互不影响.这种情况要单独拎出来 .
\(ans1=2\sum_{i=0}^n(-1)^{i+1}C_n^i3^{n(n-i)+i}\)
当\(i \in[1,n],j \in[1,n]\)时 , 即行和列都有的时候 , 颜色必须都一样 .
\(ans2=\sum_{i=1}^n\sum_{j=1}^n(-1)^{i+j+1}C_n^iC_n^j3^{(n-i)(n-j)+1}\)
和组合数\(C\)有关的式子,首先想到
\((a+b)^n=\sum_{i=0}^nC_n^ia^ib^{n-i}\)
此时次幂要简洁 , 而\(C\)不需要 , 所以把 \(n-i\) 换成 \(i\) , 把 \(n-j\) 换成 \(j\) .
\(ans2=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}(-1)^{i+j+1}C_n^iC_n^j3^{ij+1}\)
把 \(i\) 提到前面 , 把 \(j\) 放到后面
\(ans2=3\sum_{i=0}^{n-1}(-1)^{i+1}C_n^i\sum_{j=0}^{n-1}(-1)^jC_n^j3^{ij}\)
考虑后面关于 \(j\) 的式子 \(\sum_{j=0}^{n-1}C_n^j(-3^i)^j(1)^{n-j} = (-3^i+1)^n - (-3^i)^n\)
\(ans2=3\sum_{i=0}^{n-1}(-1)^{i+1}(C_n^i(-3^i+1)^n - (-3^i)^n)\)
代码实现时注意一些细节.
\(1.\)可以把 \(-3^i\) 提出来 , 清晰很多 . 然后发现和负数有关的快速幂也只有 \(-1\) 的次方 , 负数也是可以直接快速幂的 .
\(2.\)复杂的式子一定要打空格!!!
\(3.\)qpow等函数全部开LL , 而且add也不要追求速度 , 老老实实写return (a+b)%mod; 而且复杂的式子里不要用这些 .
\(4.\) a-b 一定要写成(a-b+mod)%mod
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cassert>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define Debug(x) cout<<#x<<"="<<x<<endl
using namespace std;
typedef long long LL;
const int INF=1e9+7;
inline LL read(){
register LL x=0,f=1;register char c=getchar();
while(c<48||c>57){if(c=='-')f=-1;c=getchar();}
while(c>=48&&c<=57)x=(x<<3)+(x<<1)+(c&15),c=getchar();
return f*x;
}
const int N=1e6+5;
const int mod=998244353;
int fac[N],ifac[N];
int n;
LL ans1,ans2;
inline LL add(LL x,LL y){return (x+y)%mod;}
inline LL mul(LL x,LL y){return 1ll*x*y%mod;}
inline LL qpow(LL a,LL b){
LL res=1;
for(;b;b>>=1,a=mul(a,a)) if(b&1) res=mul(res,a);
return res;
}
inline int C(int n,int m){
return mul(fac[n],mul(ifac[m],ifac[n-m]));
}
int main(){
n=read();
fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=n;i++) fac[i]=mul(fac[i-1],i);
ifac[n]=qpow(fac[n],mod-2);
for(int i=n-1;i>=2;i--) ifac[i]=mul(ifac[i+1],i+1);
if(n>1)
assert(mul(ifac[2],2)==1);
assert(445648748569745648677454784e-330); // 324位
for(int i=1;i<=n;i++){
ans1 = (ans1 + (C(n,i) * qpow(3,1ll*n*(n-i)+i) % mod * qpow(-1,i+1)) + mod) % mod;
}
for(int i=0;i<=n-1;i++){
int t = -qpow(3,i);
ans2 = (ans2 + (C(n,i) * ((qpow(t+1,n) - qpow(t,n) + mod) % mod) % mod * qpow(-1,i+1)) + mod) % mod;
}
printf("%d\n",(ans1*2+ans2*3)%mod);
}