树上问题--斯坦纳树

前言:

什么是斯坦纳树问题?就是给你一些点,一些边,这些点中有一些特殊点,求满足所有特殊点联通的情况下,所费价值的最小值,这里的价值可以是边权,也可以是点权。那么怎么求?注意这种方法仅当特殊点的数量较小,因为我们要用状压去dp。

给一道模板题:

树上问题--斯坦纳树

 

 

设dp[i][j]为以i为根,连通图中,联通了状态j所费价值的最小值。

dp[i][s]=min(dp[i][s],dp[i][subs]+dp[i][subs^s]);

这里的subs是s的子集,那么就可以这么转换。

但是如果i的“出度”为1呢,那么就转不过去了?

我们可以这么转:

dp[i][j]=min(dp[i][j],dp[x][j]+w[i][x]);

这里的x和i是有边连着的,这样的话一个状态就可以向周围扩散出去了。

是不是很像那个三角形不等式,于是我们在图上跑一遍最短路。

下面给出dij的模板:

#include<bits/stdc++.h>
using namespace std;
const int maxn=510;
int n,m,k,x,y,z,dp[maxn][4200],p[maxn],head[maxn],cnt,vis[maxn];
struct node
{
    int v,nxt,w;
}e[maxn<<2];
void add(int u,int v,int w)
{
    cnt++;
    e[cnt].v=v;
    e[cnt].w=w;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
priority_queue< pair<int,int> > q;
void dij(int s)
{
    memset(vis,0,sizeof(vis));
    while(q.size())
    {
        pair<int ,int > a=q.top();
        q.pop();
        if(vis[a.second]) continue ;
        vis[a.second]=1;
        for(int i=head[a.second];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(dp[v][s]>dp[a.second][s]+e[i].w)
            {
                dp[v][s]=dp[a.second][s]+e[i].w;
                q.push(make_pair(-dp[v][s],v));
            }
        }
    }
}
int main()
{
    memset(dp,0x3f3f3f3f,sizeof(dp));
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    for(int i=1;i<=k;i++)
    {
        scanf("%d",&p[i]);
        dp[p[i]][1<<(i-1)]=0;
    }
    for(int s=1;s<(1<<k);s++)
    {
        for(int i=1;i<=n;i++)
        {
        for(int subs=s&(s-1);subs;subs=s&(subs-1))
        {
            dp[i][s]=min(dp[i][s],dp[i][subs]+dp[i][subs^s]);
        }
        if(dp[i][s]!=0x3f3f3f3f) q.push(make_pair(-dp[i][s],i));
        }
        dij(s);
    }
    printf("%d\n",dp[p[1]][(1<<k)-1]);
    return 0;
}

复杂度还是有点高的,谨慎使用。

例题.

4294 游览计划

这个鬼题害惨我了,define真是不能随便用,一整个大无语。

就是个裸题,直接给代码:

#include<bits/stdc++.h>
#define id(x,y) ((x)*m+(y))
using namespace std;
const int inf=0x3f3f3f3f;
int cnt=0;
int a[105],vis[105],mp[105],b[105][1<<10];
int f[105][1<<10],pre[105][1<<10];
int head[105];
struct edge
{
    int v,nxt;
}e[105<<2];
priority_queue<pair<int,int> > q;
void add(int u,int v)
{
    ++cnt;
    e[cnt].v=v;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void dij(int s)
{
    while(q.size())
    {
        int x=q.top().second;
        q.pop();
        if(vis[x]) continue ;
        vis[x]=1;
        for(int i=head[x];i;i=e[i].nxt)
        {
            int v=e[i].v;
            if(f[v][s]>f[x][s]+a[v])
            {
                f[v][s]=f[x][s]+a[v];
                q.push(make_pair(-f[v][s],v));
                pre[v][s]=x;
            }
        }
    }
}
void dfs(int x,int s)
{
    if(b[x][s]) return ;
    b[x][s]=1;
    mp[x]=1;
    if(pre[x][s]!=-1 && f[pre[x][s]][s]+a[x]==f[x][s])
    {
        int tmp=pre[x][s];
        while(tmp!=-1)
        {
            dfs(tmp,s);
            tmp=pre[tmp][s];
        }
    }
    for(int s0=s&(s-1);s0;s0=s&(s0-1))
    {
        if(f[x][s0]+f[x][s^s0]-a[x]==f[x][s]) 
        {
            dfs(x,s0);
            dfs(x,s^s0);
            break;
        }
    }
}
int main()
{   
    int n,m,k=0,pos=-1;
    cin>>n>>m;
    memset(f,0x3f,sizeof(f));
    memset(pre,-1,sizeof(pre));
    for(int i=0;i<n;i++)
    for(int j=0;j<m;j++)
    {
        cin>>a[id(i,j)];
        if(!a[id(i,j)])
        {
            f[id(i,j)][1<<((++k)-1)]=0;
            pos=id(i,j);
        }
        if(i>=1)
        {
            add(id(i-1,j),id(i,j));
            add(id(i,j),id(i-1,j));
        }
        if(j>=1)
        {
            add(id(i,j-1),id(i,j));
            add(id(i,j),id(i,j-1));
        }
    }
    for(int s=1;s<(1<<k);s++)
    {
    for(int i=0;i<n*m;i++)
    {
        for(int s0=s&(s-1);s0;s0=s&(s0-1))
        {
            f[i][s]=min(f[i][s],f[i][s0]+f[i][s^s0]-a[i]);
        }
        vis[i]=0;
        if(f[i][s]!=inf) q.push(make_pair(-f[i][s],i));
    }
    dij(s);
    }
    if(pos==-1) cout<<0<<endl;
    else cout<<f[pos][(1<<k)-1]<<endl;
    dfs(pos,((1<<k)-1));
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<m;j++)
        {
            if(a[id(i,j)] && mp[id(i,j)]) cout<<"o";
            else if(a[id(i,j)] && mp[id(i,j)]==0) cout<<'_';
            else cout<<'x';
        }
        cout<<endl;
    }
    return 0;
}

找路径的话,那就是判断怎么推过来了,是dij过来的,还是递推推过来的,然后顺着往回找。

上一篇:你竟然赶我走


下一篇:MapReduce当中Combiner的用法