文章目录
1176:遍历
遍历
题意:给定N个点的图,求经过指定M个点的路径,使路径最短
思路:
修改floyd,确保floyd下M个点不会为中间dp节点。得到M个点间的距离。
题目转化为已知M个点之间的距离,求遍历这M个点的顺序使路径最小。
d
p
[
i
]
[
j
]
[
k
]
dp[i][j][k]
dp[i][j][k]中i为起点,j为终点,k是经过的节点的二进制表示。
如
d
p
[
1
]
[
2
]
[
7
]
dp[1][2][7]
dp[1][2][7]中起点为1,终点为2,二进制值为
7
10
=
11
1
2
7_{10}=111_{2}
710=1112,即从1出发到达2,经过节点123的最短路径长度。
如何求答案?
M最大为15。可以拆为M/2和M-M/2两个组成部分。第一部分起点终点为i、j,第二部分起点终点为s、t,则答案ans为:
a
n
s
=
m
i
n
{
a
n
s
,
d
p
[
i
]
[
j
]
[
k
1
]
+
d
p
[
s
]
[
t
]
[
k
2
]
+
d
i
s
[
j
]
[
s
]
}
ans = min\{ans, dp[i][j][k1]+dp[s][t][k2]+dis[j][s]\}
ans=min{ans,dp[i][j][k1]+dp[s][t][k2]+dis[j][s]}
k1是第一部分经过的点的二进制表示,k2是第二部分经过的点的二进制表示。dis是M个点之间的距离矩阵。则只需要枚举i、j、s、t和M/2的方案。
M/2最大为7,则方案数为
C
15
7
=
6435
C_{15}^7=6435
C157=6435。枚举i、j、s、t的复杂度为
O
(
7
×
7
×
8
×
8
)
=
O
(
3136
)
O(7\times 7 \times 8 \times 8)=O(3136)
O(7×7×8×8)=O(3136)。则总复杂度为这两者的乘积
≈
O
(
2
e
7
)
\approx O(2e7)
≈O(2e7)。
如何求dp?
初始化值
d
p
[
i
]
[
i
]
[
1
<
<
(
i
−
1
)
]
=
0
dp[i][i][1<<(i-1)] = 0
dp[i][i][1<<(i−1)]=0,表示i节点待在自身位置上值为0。如果k不为1<<(i-1)的其他情况则为inf,认为不可达。
因为求答案只需要到max{M/2,M-M/2}=(M+1)/2,所以dp求到这就可以了。对于经过的点数num,先处理num小的dp值再处理大的。
对于给定的num的二进制值为k的某一方案,枚举a、b、c,表示起点为a,终点为b,c为最后连接的点。则转移方程为:
d
p
[
a
]
[
c
]
[
k
]
=
m
i
n
{
d
p
[
a
]
[
b
]
[
k
−
(
1
<
<
(
c
−
1
)
)
]
+
d
i
s
[
b
]
[
c
]
}
dp[a][c][k] = min\{dp[a][b][k-(1<<(c-1))]+dis[b][c]\}
dp[a][c][k]=min{dp[a][b][k−(1<<(c−1))]+dis[b][c]}
复杂度计算略,也不会超时。
但是最后只有90分。没有找到问题在哪。
#include<bits/stdc++.h>
#define all(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define pll pair<ll,ll>
using namespace std;
typedef long long ll;
const double eps=1e-8;
const double PI=acos(-1.0);
const ll mod = 1e9+7;
const ll mx = 1e2 + 10;
const ll Mx = 20;
const ll tmx = 4e4+ 10;
ll N, M;
ll d[mx][mx];
bool b[mx];
ll dis[mx][mx], id[mx];
ll dp[Mx][Mx][tmx];//dp[i][j] from i to j
vector<ll>e[20];
ll cid[mx], cid2[mx], t1, t2;
void getid(ll x){
t1 = t2 = 0;
for(ll time = 1; time <= M; time++){
if(x&1)cid[++t1] = time;
else cid2[++t2] = time;
x>>=1;
}
}
ll f(){
//初始化
memset(dp, -1, sizeof dp);
for(ll i = 1; i <= M; i++) dp[i][i][1<<(i-1)]=0LL;
//DP
for(ll num = 2; num <= (M+1)/2; num++){
for(ll _ = 0; _ < e[num].size(); _++){
ll it = e[num][_];
getid(it);
for(ll a = 1; a <= t1; a++)
for(ll b = 1; b <= t1; b++){
//(a->b)->c
ll na = cid[a], nb = cid[b];
for(ll c = 1; c <= t1; c++){
if(c==a || c==b)continue;
ll nc = cid[c];
ll cit = it - (1<<(nc-1));
if(dp[na][nb][cit]==-1LL)continue;
if(dis[nb][nc]==-1LL)continue;
ll nv = dp[na][nb][cit]+dis[nb][nc];
if(dp[na][nc][it]==-1LL)
dp[na][nc][it] = nv;
else
dp[na][nc][it] = min(dp[na][nc][it], nv);
}
}
}
}
ll num1 = M/2, ans = -1LL;
for(ll _ = 0; _ < e[num1].size(); _++){
ll it = e[num1][_];
ll vres = (1<<M)-1-it;
getid(it);
for(ll i = 1; i <= t1; i++)
for(ll j = 1; j <= t1; j++){
ll ni = cid[i], nj = cid[j];
if(dp[ni][nj][it]==-1LL)continue;
for(ll s = 1; s <= t2; s++)
for(ll t = 1; t <= t2; t++){//(ni->nj)->(ns->nt);
ll ns = cid2[s], nt = cid2[t];
if(dp[ns][nt][vres]==-1LL)continue;
if(dis[nj][ns]==-1LL)continue;
ll nv = dp[ni][nj][it]+dp[ns][nt][vres]+dis[nj][ns];
if(ans == -1LL) ans = nv;
else ans = min(ans, nv);
}
}
}
return ans;
}
void solve(){
scanf("%lld %lld", &N, &M);
for(ll i = 1; i <= N; i++){
for(ll j = 1; j <= N; j++) {
scanf("%lld", &d[i][j]);
if(d[i][j] == 0LL && i!=j) d[i][j] = -1;
}
}
ll cM = 1;
for(ll i = 1; i <= M; i++){
scanf("%lld", &id[cM]);
id[cM]++;
if(b[id[cM]])continue;
b[id[cM]]=true;
cM++;
}
M = cM-1;
if(M == 1LL){
printf("0\n");
return;
}
//floyd
for(ll k = 1; k <= N; k++){
if(b[k])continue;
for(ll i = 1; i <= N; i++){
for(ll j = 1; j <= N; j++){
if(d[i][k]==-1LL || d[k][j]==-1LL)continue;
ll nv = d[i][k]+d[k][j];
d[i][j]=(d[i][j]==-1LL?nv:min(d[i][j], nv));
}
}
}
for(ll i = 1; i <=M ;i ++){
for(ll j =1; j <= M; j++){
if(i == j) dis[i][j] = 0LL;
else dis[i][j] = d[id[i]][id[j]];
}
}
if(M == 2LL){
ll ans = -1;
if(dis[1][2]!=-1LL) ans = dis[1][2];
if(dis[2][1]!=-1LL){
if(ans!=-1LL) ans = min(ans, dis[2][1]);
else ans = dis[2][1];
}
printf("%lld\n", ans);
return;
}
//预处理
for(ll i = 1; i < (1<<M); i++){
ll num = 0, x = i;
while(x){
if(x&1)num++;
x>>=1;
}
e[num].pb(i);
}
ll ans = f();
printf("%lld\n", ans);
}
int main(){
solve();
return 0;
}