ICPC2021网络赛第二场 K Meals 概率 dp

第\(i\)个人能选\(j\)​个物品,当且仅当前\(i-1\)个人没选第\(j\)个物品,且\(i\)的喜好序列中,第\(j\)个物品是未被选走的里排最前面的。

第\(i\)个人喜好序列中第\(j\)个物品排在任意一个位置的概率均为\(a_{ij}/\sum_{k=1}^{n}a_{ik}\),

前\(i\)​个人选的物品的集合为\(S\)​​的概率,记为\(p(i,S)\)​,则第\(i\)​个人选第\(j\)​个物品的概率为

\[\sum_{S,S中有i-1个1}[j\notin S]\times p(i-1, S)\times \frac{a[i][j]}{\sum_{k=1}^na[i][k]\times [k \notin S]} \]

,知道\(p\)​之后后面的东西可以\(O(2^n n)\)​​​算,

\[p(i,S) = \sum_{j=1}^{n}[j\in S,S中有i个1]\times p(i-1,S-(1<<(j-1)))\times \frac{a[i][j]}{a[i][j]+\sum_{k=1}^na[i][k]\times [k \notin S]} \]

数组开小wa半天,没有线性求逆元会tle

#include<bits/stdc++.h>
using namespace std;
const int maxn = 20 + 7;
#define ll long long
const ll md = 998244353;
ll n, a[maxn][maxn], sum[maxn][maxn], p[maxn][1048577], ans[maxn][maxn], cnt[1048577]; 
ll rd() {
	ll s = 0, f = 1; char c = getchar();
	while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
	while (c >= '0' && c <= '9') {s = s * 10 + c - '0'; c = getchar();}
	return s * f;
}
ll ksm(ll a, ll b) {
	ll res = 1;
	while (b) {
		if (b & 1ll) res = res * a % md;
		a = a * a % md;
		b >>= 1ll;
	}
	return res;
}
ll inv[2007];
int main() {
	n = rd();
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= n; j++) {
			a[i][j] = rd();
			//sum[i][j] = (sum[i][j-1] + a[i][j]) % md;
		}
	}
	inv[0] = inv[1] = 1;
	for (int i = 2; i <= 2000; i++) 
		inv[i] = inv[md % i] * (md - md / i) % md;
	//for (int i = 1; i <= n; i++) ans[1][i] = a[1][i] * ksm(sum[1][n], md - 2) % md;
	//p[i][S]表示前i个人已选集合为S的概率。
	for (int S = 1; S < (1 << n); S++) {
		if (S & 1) cnt[S] = cnt[S>>1] + 1;
		else cnt[S] = cnt[S>>1];
		//p[0][S] = 1;
	}
	p[0][0] = 1;
	for (int i = 1; i <= n; i++) {
		for (int S = 0; S < (1 << n); S++) {
			ll sum0 = 0;
			for (int j = 0; j < n; j++) {
				if (S & (1 << j)) continue;
				sum0 = (sum0 + a[i][j+1]) % md;
			}
			if (cnt[S] == i-1) {
				for (int j = 1; j <= n; j++) {
					if ((1 << (j-1)) & S) continue;
						ans[i][j] = (ans[i][j] + p[i-1][S] * a[i][j] % md * inv[sum0] % md) % md;
				}
			}
			if (cnt[S] == i) {
				for (int j = 1; j <= n; j++) {
					if (S & (1 << (j-1))) {						
						p[i][S] = (p[i][S] + p[i-1][S^(1<<(j-1))]*a[i][j]%md*inv[sum0+a[i][j]]%md)%md;
					}
				}
			}
		}
	}
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= n; j++) {
			printf("%lld", ans[i][j]);
			if (j != n) printf(" ");
		}
		if (i != n)
			printf("\n");
	}
	return 0;
}
上一篇:html和css基础


下一篇:3.sql进阶文档(知识点)