题目大意
有 \(M\) 个球,一开始每个球均有一个初始标号,标号范围为 \(1\) ~ \(N\) 且为整数,标号为 \(i\) 的球有 \(a_i\) 个,并保证 \(\sum a_i = M\)。
每次操作等概率取出一个球(即取出每个球的概率均为 \(1\over M\)),若这个球标号为 \(k\ (k < N)\),则将它重新标号为 \(k+1\);若这个球标号为 \(N\),则将其重标号为 \(1\)。(取出球后并不将其丢弃)
现在你需要求出,经过 \(K\) 次这样的操作后,每个标号的球的期望个数。
数据范围
\(N ≤ 1000, M ≤ 100,000,000, K ≤ 2,147,483,647\)。
思路
第一次见到循环矩阵优化 dp 的套路,记录一下。
转移方程很好得到,设 \(f[i][j]\) 表示到第 \(i\) 轮 \(j\) 编号的球的期望个数,转移方程就是
\[f[i][j]=\cfrac{m-1}{m}\ f[i-1][j]+\cfrac{1}{m}\ f[i-1][j-1]\ (2\leq j\leq n) \] \[f[i][1]=\cfrac{m-1}{m}\ f[i-1][1]+\cfrac{1}{m}\ f[i-1][n] \]通过 \(K\) 的范围的提示,我们冲一个矩阵快速幂即可,时间效率 \(O(n^3\log K)\)
\(n\leq 1000\)
那没事了。
假设 \(n=4\),我们构造出转移矩阵:
\[ \left[ \begin{matrix} f[i-1][1] & f[i-1][2] & f[i-1][3] & f[i-1][4] \end{matrix} \right] \times \left[ \begin{matrix} \cfrac{m-1}{m} & \cfrac{1}{m} & 0 & 0 \\ 0 & \cfrac{m-1}{m} & \cfrac{1}{m} & 0 \\ 0 & 0 & \cfrac{m-1}{m} & \cfrac{1}{m} \\ \cfrac{1}{m} & 0 & 0 & \cfrac{m-1}{m} \end{matrix} \right] = \left[ \begin{matrix} f[i][1] & f[i][2] & f[i][3] & f[i][4] \end{matrix} \right] \]我们发现转移矩阵是一个循环矩阵。
那么这个矩阵满足什么性质呢?
我们设第一排的第 \(i\) 个数为 \(k[i]\),我们以 \(k[1]\) 为例:
\[k[1]=a[1][1]\times a[1][1]+a[1][2]\times a[2][1]+a[1][3]\times a[3][1]+a[1][4]\times a[4][1] \]我们将其对应到第一行的元素,得到:
\[k[1]=k[1]\times k[1]+k[2]\times k[4]+k[3]\times k[3]+k[4]\times k[2] \]很容易看出性质:
\[k[t]=\sum\limits_{(i+j-2)\equiv t\pmod n}k[i]\times k[j] \]所以我们只需要记录第一行的状态,用 \(O(n^2\log K)\)转移即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn=1000+10;
int n,m,K;
struct Mat{
double a[maxn];
Mat(){
memset(a,0,sizeof(a));
}
friend inline Mat operator *(register const Mat& A,register const Mat& B){
Mat C;
for(register int i=1;i<=n;i++)
for(register int j=1;j<=n;j++)
C.a[(i+j-2)%n+1]+=A.a[i]*B.a[j];
return C;
}
}ans,base;
inline int read(){
int x=0;bool fopt=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')fopt=0;
for(;isdigit(ch);ch=getchar())x=(x<<3)+(x<<1)+ch-48;
return fopt?x:-x;
}
inline void qpow(int b){
while(b){
if(b&1)ans=ans*base;
base=base*base;
b>>=1;
}
}
int main(){
n=read();m=read();K=read();
for(int i=1;i<=n;i++)
ans.a[i]=read();
base.a[1]=1.0*(m-1)/m;
base.a[2]=1.0/m;
qpow(K);
for(int i=1;i<=n;i++)
printf("%.3lf\n",ans.a[i]);
return 0;
}