【模板】二分图最大权完美匹配
题目链接:luogu P6577
题目大意
一个二分图,有一些带权边,保证有完美匹配。
求一种最大匹配的方案使得匹配边的边权和最大。
思路
KM 算法的模板题。
它有一定的针对性:一定要是带权的完美匹配。
然后我们定义每个点有一个顶表(一个值)\(e_x\)。
对于每一条边 \((u,v)\),我们要满足 \(e_u+e_v\geqslant w(u,v)\)。
然后如果 \(e_u+e_v=w(u,v)\),那我们就可以把这条边放进二分图中。
那如果这个时候的图能跑出完美匹配,那此时的边权和即是结果。
但是你发现定标的值不能直接确定,所以算法的流程是这样的:
确定定标的值,跑匹配判断(如果匹配是 \(n\) 就结束),否则就修改定标,然后重复操作。
然后接着问题就是一开始怎么给定标的值,而且每次怎么修改。
一开始给定标,我们可以让 \(ex_x=0,ey_x=\max\limits_{i=1}^nw(x,i)\)。
然后每次修改,我们就是要减少一些 \(e_u+e_v\) 的值使得更多边在图中。
那我们修改就是找一条边 \((i,j)\),它一个不在最大匹配,一个在。
那我们要让他加入,我们就要让他满足条件,那定标和要减少 \(d_i+d_j-w(i,j)\)。
那因为 \(j\) 已经在最大匹配中了,所以我们就直接把二分图最大匹配中的任意点 \(i\) 都把 \(ex_i+d\) 或者把 \(ey_i-d\)。
那我们为了满足 \(e_u+e_v\geqslant w(u,v)\),所以我们要每次的 \(d\) 尽量小。
那每次找边复杂度 \(O(n^2)\),二分图匹配的复杂度是 \(O(n^2)\),总的复杂度为 \(O(n^4)\)。
然后发现每次都暴力找 \(d\) 太慢了,我们考虑用一个数组 \(slack_j\) 表示 \(ex_i+ey_i-w(i,j)\) 的最小值,然后在跑增广路的时候修改即可做到 \(O(n^3)\)。
吗?
其实会假。因为如果你匹配的部分跑到 \(O(n^2)\),它还是 \(O(n^4)\)。
然后我们发现我们每次只是修改了一条边,所以我们匹配的时候有一部分是跟原来一样的。
然后我们把 dfs 改成 bfs,就可以真正的变成 \(O(n^3)\)。
代码
\(O(n^4)\) 版
#include<cstdio>
#include<iostream>
#define ll long long
#define INF 0x3f3f3f3f3f3f3f3f
using namespace std;
int n, m, x, y, matched[501];
ll dis[501][501], ex[501], ey[501], slack[501], w;
bool inx[501], iny[501];
bool match(int now) {
iny[now] = 1;
for (int i = 1; i <= n; i++) {
if (inx[i]) continue;
ll g = ex[i] + ey[now] - dis[now][i];
if (!g) {
inx[i] = 1;
if (!matched[i] || match(matched[i])) {
matched[i] = now; return 1;
}
}
else slack[i] = min(slack[i], g);
}
return 0;
}
ll KM() {
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) slack[j] = INF;
while (1) {
for (int j = 1; j <= n; j++) inx[j] = iny[j] = 0;
if (match(i)) break;
ll d = INF;
for (int j = 1; j <= n; j++) {
if (!inx[j]) d = min(d, slack[j]);
}
for (int j = 1; j <= n; j++) {
if (iny[j]) ey[j] -= d;
if (inx[j]) ex[j] += d;
else slack[j] -= d;
}
}
}
ll re = 0;
for (int i = 1; i <= n; i++)
re += dis[matched[i]][i];
return re;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) {
ey[i] = -INF;
for (int j = 1; j <= n; j++)
dis[i][j] = -INF;
}
for (int i = 1; i <= m; i++) {
scanf("%d %d", &x, &y); scanf("%lld", &w);
dis[x][y] = max(dis[x][y], w);
ey[x] = max(ey[x], dis[x][y]);
}
printf("%lld\n", KM());
for (int i = 1; i <= n; i++)
printf("%d ", matched[i]);
return 0;
}
\(O(n^3)\) 版
#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long
#define INF 0x3f3f3f3f3f3f3f3f
using namespace std;
int n, m, x, y, matched[501], bef[501];
ll dis[501][501], ex[501], ey[501], slack[501], w;
bool iny[501];
void match(int now) {
int x, y = 0, ty = 0;
matched[y] = now;
while (1) {
x = matched[y]; ll d = INF; iny[y] = 1;
for (int i = 1; i <= n; i++) {
if (iny[i]) continue;
if (slack[i] > ex[x] + ey[i] - dis[x][i]) {
slack[i] = ex[x] + ey[i] - dis[x][i];
bef[i] = y;
}
if (slack[i] < d) {
d = slack[i]; ty = i;
}
}
for (int i = 0; i <= n; i++) {
if (iny[i]) ex[matched[i]] -= d, ey[i] += d;
else slack[i] -= d;
}
y = ty;
if (!matched[y]) break;
}
while (y) {
matched[y] = matched[bef[y]];
y = bef[y];
}
}
ll KM() {
for (int i = 1; i <= n; i++) {
memset(iny, 0, sizeof(iny));
for (int j = 1; j <= n; j++) slack[j] = INF;
memset(bef, 0, sizeof(bef));
match(i);
}
ll re = 0;
for (int i = 1; i <= n; i++)
if (matched[i])
re += dis[matched[i]][i];
return re;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) {
ey[i] = -INF;
for (int j = 1; j <= n; j++)
dis[i][j] = -INF;
}
for (int i = 1; i <= m; i++) {
scanf("%d %d %lld", &x, &y, &w);
dis[x][y] = max(dis[x][y], w);
ey[x] = max(ey[x], dis[x][y]);
}
printf("%lld\n", KM());
for (int i = 1; i <= n; i++)
printf("%d ", matched[i]);
return 0;
}