基于A*算法的KSP问题求解

#include <bits/stdc++.h>

using namespace std;

// "Unreachable" marker for path costs: far larger than any real hop count,
// yet small enough that adding a link cost to it cannot overflow int.
constexpr int INF = 1000000000;

// h(x) of the A* search: H[v] holds the hop distance from v to the destination.
// main fills it by running Dijkstra from dest on the reversed topology.
vector<int> H;

/**
 * Reads directed edges "u v w" from stdin until input ends and appends v to
 * topo[u]. topo is sized for node ids 0..nodeNum.
 *
 * The weight w is read but discarded: every link in this program costs
 * TOPO_LINK_DEFAULT_COST.
 *
 * Reading stops at the first record that does not parse as three integers.
 * Previously such a record made scanf return 0 without consuming input, and
 * the loop spun forever. Edges whose endpoints lie outside [0, nodeNum] are
 * reported on stderr and skipped instead of indexing topo out of bounds.
 */
void LoadTopo(vector<vector<int>> &topo, int nodeNum)
{
    // A negative count would otherwise become a huge size_t in resize().
    topo.resize(nodeNum < 0 ? 0 : static_cast<size_t>(nodeNum) + 1);
    int u, v, w;
    while (scanf("%d %d %d", &u, &v, &w) == 3) {
        if (u < 0 || u > nodeNum || v < 0 || v > nodeNum) {
            fprintf(stderr, "LoadTopo: skipping edge %d -> %d (node id out of range)\n", u, v);
            continue;
        }
        topo[u].push_back(v);
    }
}

// Builds the transpose of `topo`: every edge u -> v becomes v -> u in revTopo.
// revTopo is resized to topo's node count; edges it already held are kept.
void ReverseTopo(const vector<vector<int>> &topo, vector<vector<int>> &revTopo)
{
    const size_t nodeCount = topo.size();
    revTopo.resize(nodeCount);
    size_t from = 0;
    while (from < nodeCount) {
        const vector<int> &outEdges = topo[from];
        for (size_t i = 0; i < outEdges.size(); ++i) {
            revTopo[outEdges[i]].push_back(static_cast<int>(from));
        }
        ++from;
    }
}

// Cost of traversing any single link. The topology keeps no per-edge weights
// (LoadTopo reads the weight column but discards it), so every hop costs this.
constexpr int TOPO_LINK_DEFAULT_COST  = 1;

/**
 * Single-source shortest distances from `src` over `topo`, where every link
 * costs TOPO_LINK_DEFAULT_COST.
 *
 * @param topo  adjacency list: topo[u] lists the successors of u
 * @param cost  output; resized to topo.size(). cost[v] is the distance
 *              src -> v, or INF if v is unreachable (or src is not a node)
 * @param src   start node
 *
 * main runs this on the reversed topology from dest, which yields each node's
 * distance *to* dest (the A* heuristic H). With uniform link costs a BFS would
 * give the same answer; the heap version is kept for weighted links.
 */
void Dijkstra(const vector<vector<int>> &topo, vector<int> &cost, int src)
{
    using P = pair<int, int>; // (tentative distance, node)
    cost.assign(topo.size(), INF);
    if (src < 0 || static_cast<size_t>(src) >= topo.size()) {
        return; // nothing is reachable from a node outside the topology
    }
    priority_queue<P, vector<P>, greater<P>> que; // min-heap on distance
    cost[src] = 0;
    que.emplace(0, src);
    while (!que.empty()) {
        const P top = que.top();
        que.pop();
        const int u = top.second;
        if (top.first > cost[u]) {
            continue; // stale entry: u was already settled with a smaller distance
        }
        const int viaU = cost[u] + TOPO_LINK_DEFAULT_COST;
        for (int v : topo[u]) {
            if (viaU < cost[v]) {
                cost[v] = viaU;
                que.emplace(viaU, v);
            }
        }
    }
}

/**
 * Prints a path as two lines: "path cost : <hops>" and the nodes joined by "->".
 *
 * The cost is the hop count (path.size() - 1), which equals the path cost
 * because every link costs TOPO_LINK_DEFAULT_COST == 1. An empty path prints
 * cost 0 and an empty node line; before, size() - 1 underflowed size_t.
 */
void PrintPath(const vector<int> &path)
{
    const size_t hops = path.empty() ? 0 : path.size() - 1;
    cout << "path cost : " << hops << '\n';
    for (size_t i = 0; i < path.size(); ++i) {
        if (i != 0) {
            cout << "->";
        }
        cout << path[i];
    }
    cout << '\n';
}

// One A* search state: "a walk from the start node currently ending at curNodeId".
// States are linked to their predecessor through lstId -> curId, so a finished
// walk can be rebuilt by following those links backwards.
struct CurTraverStatus {
    int lstId;     // curId of the predecessor state
    int curId;     // id of this state
    int curNodeId; // topology node the walk currently ends at
    int g;         // g(x): cost accumulated from the start node

    // Debug dump of the state together with its heuristic value H[curNodeId].
    void pt()
    {
        cout << "curNodeId " << curNodeId
             << " lstId " << lstId
             << " curId " << curId
             << " val: " << g << " " << H[curNodeId] << endl;
    }

    // Deliberately inverted: std::priority_queue keeps the "largest" element on
    // top, so ranking a larger f = g + H as "less" makes it pop the smallest f first.
    bool operator < (const CurTraverStatus &other) const {
        const int fMine = g + H[curNodeId];
        const int fOther = other.g + H[other.curNodeId];
        return fOther < fMine;
    }
};

void Recordpath(map<int, CurTraverStatus> &mem, CurTraverStatus curStatus, int src)
{
    vector<int> path;
    CurTraverStatus &cur = curStatus;
    while (true) {
        path.push_back(cur.curNodeId);
        if (cur.curNodeId == src) {
            break;
        }
        cur = mem[cur.lstId];
    }
    reverse(path.begin(), path.end());
    PrintPath(path);
}

void KSPByAStartAlgo(const vector<vector<int>> &topo, int src, int dest, int k)
{
    static int ID = 0;
    map<int, CurTraverStatus> mem;
    priority_queue<CurTraverStatus> que;
    CurTraverStatus first = (CurTraverStatus){ID, ID++, src, 0};
    mem[first.curId] = first;
    que.push(first);
    vector<int> cnt(topo.size(), 0);
    vector<bool> isInQueue(topo.size(), false);
    isInQueue[src] = true;
    while (!que.empty()) {
        CurTraverStatus cur = que.top();
        que.pop();
        // cur.pt();
        cnt[cur.curNodeId]++;
        if (cnt[cur.curNodeId] > k) {
            continue;
        }
        if (cur.curNodeId == dest) {
            // 第 cnt[cur.curNodeId] 小的路径生成成功
            // cout << "\ncurK = " << cnt[cur.curNodeId] << endl;
            // Recordpath(mem, cur, src);
            if (cnt[cur.curNodeId] == k) {
                break;
            }
        }
        for (int v : topo[cur.curNodeId]) {
            // cout << cur.curNodeId << " -> " << v << endl;
            if (isInQueue[v]) {
                // 该判断避免因为想走短的路绕环
                continue;
            }
            CurTraverStatus newStatus = (CurTraverStatus){cur.curId, ID++, v, cur.g + TOPO_LINK_DEFAULT_COST};
            mem[newStatus.curId] = newStatus;
            que.push(newStatus);
        }
    }
}

int main()
{
    // All input comes from in.txt in the working directory (see the generator
    // script). If it cannot be opened, stop now rather than reading from a
    // closed stdin and using uninitialized values.
    if (freopen("in.txt", "r", stdin) == nullptr) {
        fprintf(stderr, "cannot open in.txt\n");
        return 1;
    }
    // ios::sync_with_stdio(0);
    // cin.tie(0);
    clock_t start = clock();

    // Header fields. edgeNum is informational only: LoadTopo reads edges until
    // input ends.
    int nodeNum = 0;
    int edgeNum = 0;
    int src = 0;
    int dest = 0;
    int k = 0;
    if (scanf("%d %d %d %d %d", &nodeNum, &edgeNum, &src, &dest, &k) != 5) {
        fprintf(stderr, "bad header: expected \"nodeNum edgeNum src dest k\"\n");
        return 1;
    }
    // LoadTopo sizes the topology for ids 0..nodeNum; any other src/dest would
    // index H and topo out of bounds.
    if (nodeNum < 0 || src < 0 || src > nodeNum || dest < 0 || dest > nodeNum) {
        fprintf(stderr, "src and dest must be node ids in [0, nodeNum]\n");
        return 1;
    }

    vector<vector<int>> topo;
    LoadTopo(topo, nodeNum);

    clock_t end = clock();
    cout << "loadTopo " << static_cast<double>(end - start) / CLOCKS_PER_SEC << "s" << endl;

    start = clock();

    // Heuristic: the distance *to* dest equals the distance *from* dest on the
    // reversed graph.
    vector<vector<int>> revTopo;
    ReverseTopo(topo, revTopo);
    Dijkstra(revTopo, H, dest);
    if (H[src] == INF) {
        cout << "can't arrival" << endl;
        return 0;
    }

    KSPByAStartAlgo(topo, src, dest, k);

    end = clock();
    cout << "calc " << static_cast<double>(end - start) / CLOCKS_PER_SEC << "s" << endl;
    return 0;
}

生成测试用例的 Python 代码

import sys


def generate_complete_graph(n=4000, s=1, t=None, k=1000, out=None):
    """Write a KSP test case for the complete directed graph on nodes 1..n.

    The output format is the one the C++ program reads. The header line is
    "n m s t k", where m = n * (n - 1). It is followed by one "u v 1" line per
    directed edge; both directions of every node pair are emitted.

    Args:
        n: number of nodes.
        s: source node.
        t: target node; defaults to n.
        k: number of shortest paths requested.
        out: text stream to write to; defaults to sys.stdout.
    """
    if t is None:
        t = n
    if out is None:
        out = sys.stdout
    m = n * (n - 1)
    out.write('{0} {1} {2} {3} {4}\n'.format(n, m, s, t, k))
    for i in range(1, n + 1):
        # One write per row instead of one print per line: for n = 4000 that is
        # 4000 writes instead of ~16 million print calls.
        out.write(''.join('{0} {1} 1\n{1} {0} 1\n'.format(i, j) for j in range(i + 1, n + 1)))


if __name__ == '__main__':
    generate_complete_graph()

性能:
在 4000 个点、4000 * 3999 条有向边的完全图上，求前 1000 条最短路的开销：
基于A*算法的KSP问题求解

使用方法：
1. 将上面 Python 代码的输出重定向到 in.txt，例如 `python3 gen.py > in.txt`（gen.py 为保存该脚本的文件名）；
2. 编译并运行上面的 C++ 代码，程序会从当前工作目录下的 in.txt 读取输入。

上一篇:第十七天python3 文件IO(三)


下一篇:吴恩达深度学习-第一课神经网络和深度学习-第2周课后编程