题目描述
小T准备在家里摆放几幅画,为此他买来了N幅画和N个画框。为了体现他的品味,小T希望能合理地搭配画与画框,使得其显得既不过于平庸也不太违和。
对于第 幅画与第 个画框的配对,小T都给出了这个配对的平凡度Aij 与违和度Bij 。整个搭配方案的总体不和谐度为每对画与画框平凡度之和与每对画与画框违和度的乘积。具体来说,设搭配方案中第i幅画与第Pi个画框配对,则总体不和谐度为 小T希望知道通过搭配能得到的最小的总体不和谐度是多少。
输入输出格式
输入格式:
输入文件第 行是一个正整数T ,表示数据组数,接下来是T组数据。对于每组数据,第
行是一个正整数N,表示有N对画和画框。第2到第N+1行,每行有N个非负整数,第i+1 行第j个数表示Aij
。第N+2到第2*N+1行,每行有N个非负整数,第i+N+1 行第j个数表示Bij 。
输出格式:
包含T!行,每行一个整数,表示最小的总体不和谐度
输入输出样例
说明
第1幅画搭配第3个画框,第2幅画搭配第1个画框,第3 幅画搭配第2个画框,则总体不和谐度为30
N<=70,T<=3,Aij<=200,Bij<=200
学了最小乘积生成树后发现这题很相似
只不过把生成树换成了KM
还有这题求的是最小值,而KM求的是最大匹配
所以权要取负,算出来的点也取负
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long lol;
struct Node
{
lol x,y;
};
Node operator -(Node u,Node v)
{
return (Node){u.x-v.x,u.y-v.y};
}
Node operator +(Node u,Node v)
{
return (Node){u.x+v.x,u.y+v.y};
}
lol operator *(Node u,Node v)
{
return u.x*v.y-u.y*v.x;
}
lol a[][],b[][],c[][];
bool vis1[],vis2[];
lol E1[],E2[],n,match[],slack[];
lol ans;
bool dfs(int x)
{int i,j;
vis1[x]=;
for (i=;i<=n;i++)
if (vis2[i]==)
{
lol gap=E1[x]+E2[i]-c[x][i];
if (gap==)
{
vis2[i]=;
if (match[i]==-||dfs(match[i]))
{
match[i]=x;
return ;
}
}
else slack[i]=min(slack[i],gap);
}
return ;
}
Node KM()
{int i,j;
lol tmp;
memset(match,-,sizeof(match));
for (i=;i<=n;i++)
{
E2[i]=;
tmp=;
for (j=;j<=n;j++)
{
if (c[i][j]>tmp) tmp=c[i][j];
}
E1[i]=tmp;
}
for (i=;i<=n;i++)
{
for (j=;j<=n;j++)
slack[j]=1e15;
while ()
{
memset(vis1,,sizeof(vis1));
memset(vis2,,sizeof(vis2));
if (dfs(i)) break;
lol tmp=1e15;
for (j=;j<=n;j++)
if (vis2[j]==&&tmp>slack[j]) tmp=slack[j];
for (j=;j<=n;j++)
{
if (vis1[j]) E1[j]-=tmp;
if (vis2[j]) E2[j]+=tmp;
else slack[j]-=tmp;
}
}
}
lol ansa=,ansb=;
for (i=;i<=n;i++)
ansa+=a[match[i]][i],ansb+=b[match[i]][i];
return (Node){-ansa,-ansb};
}
lol cross(Node p,Node A,Node B)
{
return (A-p)*(B-p);
}
void solve(Node B,Node A)
{int i,j;
Node C;
lol x=(A.y-B.y),y=(B.x-A.x);
for (i=;i<=n;i++)
for (j=;j<=n;j++)
c[i][j]=-a[i][j]*x-b[i][j]*y;
C=KM();
if (C.x*C.y<ans) ans=C.x*C.y;
if (cross(C,B,A)>=) return;
solve(B,C);
solve(C,A);
}
int main()
{Node A,B;
int i,j,T;
cin>>T;
while (T--)
{
cin>>n;
ans=2e9;
for (i=;i<=n;i++)
for (j=;j<=n;j++)
scanf("%lld",&a[i][j]);
for (i=;i<=n;i++)
for (j=;j<=n;j++)
scanf("%lld",&b[i][j]);
for (i=;i<=n;i++)
for (j=;j<=n;j++)
c[i][j]=-b[i][j];
A=KM();
if (A.x*A.y<ans) ans=A.x*A.y;
for (i=;i<=n;i++)
for (j=;j<=n;j++)
c[i][j]=-a[i][j];
B=KM();
if (B.x*B.y<ans) ans=B.x*B.y;
solve(B,A);
printf("%lld\n",ans);
}
}