此题不提供链接
题目描述
前言
这题真的偏水(谁 T M _{TM} TM做对了还来润我的是 S B _{SB} SB),但是我犯了某个板子上的失误,所以值得写一写。
题解
当A中放入第 x x x 个数的时候,对答案的贡献的期望显然是 x x + 1 s \frac{x}{x+1}s x+1xs,其中 s s s 是B集合中元素的平均值。
所以我们可以枚举放入第几个数,然后使用预处理好的逆元直接计算贡献即可。
代码
这个线性逆元预处理,有2个办法:一个是用线性递推逆元的公式:
i
n
v
(
i
)
=
P
−
⌊
P
i
⌋
⋅
i
n
v
(
P
m
o
d
i
)
m
o
d
P
inv(i)=P-\lfloor\frac{P}{i}\rfloor\cdot inv(P\bmod i)\bmod P
inv(i)=P−⌊iP⌋⋅inv(Pmodi)modP
另一个是从组合数预处理中抠出来的,用阶乘的逆元求普通逆元:
i
n
v
(
i
!
)
=
i
n
v
(
(
i
+
1
)
!
)
⋅
(
i
+
1
)
m
o
d
P
i
n
v
(
i
)
=
i
n
v
(
i
!
)
⋅
(
i
−
1
)
!
inv(i!)=inv((i+1)!)\cdot (i+1)\bmod P\\ inv(i)=inv(i!)\cdot (i-1)!
inv(i!)=inv((i+1)!)⋅(i+1)modPinv(i)=inv(i!)⋅(i−1)!
第一种方法更常用,第二种方法非常蠢,但是好记。当你记不住第一种方法又懒得推一遍的时候便可以用第二种方法。
我就是打的第二种。
我代码习惯不好,总是先把0和1的阶乘逆元先打好,再递推其它的。于是当我想要省下一个阶乘数组,在板子上改动的时候,就把1的逆元删掉了。
这导致 n = 1 n=1 n=1 的数据全输出0。
#include<cstdio>//JZM yyds!!
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<ctime>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define ll long long
#define uns unsigned
//#define MOD 998244353ll
#define MAXN 20000005
#define INF 1e17
#define IF (it->first)
#define IS (it->second)
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0)f^=(s=='-'),s=getchar();
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int pt[30],lp;
inline void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
pt[lp=1]=x%10;
while(x>9)x/=10,pt[++lp]=x%10;
while(lp)putchar(pt[lp--]^48);
putchar(c);
}
inline ll lowbit(ll x){return x&-x;}
inline ll ksm(ll a,ll b,ll mo){
ll res=1;
for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
return res;
}
const ll MOD=998244353;
ll inv[MAXN];
inline void init(int n){
ll fac=1;
for(int i=2;i<=n;i++)fac=fac*i%MOD;
inv[n]=ksm(fac,MOD-2,MOD);
for(int i=n-1;i>=1;i--)inv[i]=inv[i+1]*(i+1)%MOD;
fac=1;
for(int i=2;i<=n;i++)inv[i]=inv[i]*fac%MOD,fac=fac*i%MOD;
}
int n,m,k;
ll s,ans;
signed main()
{
freopen("mos.in","r",stdin);
freopen("mos.out","w",stdout);
n=read(),m=read(),k=n*m;
init(k+1);
for(int i=1;i<=n;i++)s+=read();
s=s%MOD*inv[n]%MOD;
printf("%lld %d\n",s,k);
for(int i=1;i<=k;i++){
ll ad=s*i%MOD*inv[i+1]%MOD;
ans+=ad;
if(ans>=MOD)ans-=MOD;
}
printf("%lld\n",ans);
return 0;
}