仅供自己学习
思路:
按照题目意思,规约为要获得一棵最小生成树。只是边权变成了曼哈顿距离。那么就有两种MST的方法
第一种就是prim算法:
基本的prim算法就是维护两个数据结构,第一个是最小生成树集合MST,和其他每个点到最小生成树的最短距离的数组lowcost。
一开始我们就建立一个邻接矩阵存放点与点之间的曼哈顿距离,在建立上述两个集合,最小生成树集合的数组初始化全为-1,当加入一个节点就把该节点的位置的-1改为0,其他点到最小生成树的最短距离的数组初始化全为INT_MAX。
然后把start点加入最小生成树集合,并更新上述两个数组。其他点到最小生成树集合的最短距离是每个点直连的距离不能选有中间节点的路径,尽管这样做距离会更小。
然后在把其余的点加进来,每次加的都是当前距离最小生成树距离最短的点,我们遍历lowcost获得最小的距离,然后把该点加入进MST并记录这个点的距离和下标,然后总距离加上这个距离。然后更新lowcost这个集合,如果该节点未加入进MST并且当前lowcost大于其他节点到达新加入节点的距离,那么就替换成最小的。
代码:
1 class Solution { 2 public: 3 int prim(vector<vector<int>>& points,int start) { 4 int size=points.size(); 5 int res=0; 6 vector<int> MST(size,-1); 7 vector<int> lowcost(size,INT_MAX); 8 vector<vector<int>> g(size, vector<int>(size)); 9 for(int i=0;i<size;++i){ 10 for(int j=i+1;j<size;++j){ //构建无向图,每个节点之间赋值曼哈顿距离。j=i+1减少重复赋值 11 int cost=abs(points[i][0]-points[j][0])+abs(points[i][1]-points[j][1]); 12 g[i][j]=cost; 13 g[j][i]=cost; 14 } 15 } 16 17 MST[start]=0; //加入start节点并且更新两个数组 18 for(int i=0;i<size;++i){ 19 if(i==start) continue; 20 lowcost[i]=g[i][start]; 21 } 22 23 for(int i=1;i<size;++i){//加入其他节点做的处理 24 int min_idx=-1; 25 int min_cost=INT_MAX; 26 for(int j=0;j<size;j++){ 27 if(MST[j]==0) continue; 28 if(lowcost[j]<min_cost){ 29 min_idx=j; 30 min_cost=lowcost[j]; 31 } 32 } 33 34 MST[min_idx]=0; 35 lowcost[min_idx]=-1; 36 res+=min_cost; 37 38 for(int j=0;j<size;j++){ 39 if(MST[j]==-1&&g[j][min_idx]<lowcost[j]){ 40 lowcost[j]=g[j][min_idx]; 41 } 42 } 43 } 44 return res; 45 } 46 int minCostConnectPoints(vector<vector<int>>& points){ 47 return prim(points,0); 48 } 49 };
prim算法是可以用堆进行优化,节省寻找最短距离的时间。
这次用邻接链表存储无向图,但不在存储距离。堆的元素按照(距离,下标)来建立,这样才能按照距离从小到大的排序。
刚开始是一样的步骤,只是建立了堆后,加入(0,start)后 进入while循环,循环条件为堆不为空。每次取出栈顶元素,默认是最短距离,判断是否已经加入MST集合,没有就加入,并且res加上这个最短距离,记录这个点的下标。遍历该点的邻接链表,并计算邻接链表里的节点到该点的曼哈顿距离,判断如果遍历到的节点没有加入到MST并且新的曼哈顿距离小于原lowcost里存储的距离,那么就更新lowcost。然后将该点和距离加入进堆里。
代码;
1 class Solution { 2 public: 3 int prim(vector<vector<int>>& points,int start) { 4 int size=points.size(); 5 if(size==0) return 0; 6 int res=0; 7 vector<int> MST(size,-1); 8 vector<int> lowcost(size,INT_MAX); 9 vector<vector<int>> g 10 for(int i=0;i<size;++i){ 11 for(int j=i+1;j<size;++j){ //建立邻接链表 12 if(i==j) continue; 13 g[i].push_back(j); 14 g[j].push_back(i); 15 } 16 } 17 priority_queue<pair<int,int>,vector<pair<int,int>>,greater<>> pq; 18 pq.push(start); 19 while(!pq.empty()){ 20 auto [dis,idx]=pq.top(); pq.pop(); 21 if(MST[idx]==0)continue; 22 MST[idx]=0; 23 res+=dis; 24 for(int j=0;j<size;j++){ 25 int neighbor=g[idx][j]; 26 int absdis=abs(points[idx][0]-points[j][0])+abs(points[idx][1]-points[j][1]); 27 if(MST[j]==-1&&absdis<lowcost[j]){ 28 lowcost[j]=absdis; 29 pq.push(make_pair(absdis,j)); 30 } 31 } 32 } 33 return res; 34 } 35 int minCostConnectPoints(vector<vector<int>>& points){ 36 return prim(points,0); 37 } 38 };
kruskal算法和prim算法区别在于,kruskal针对边,对边进行排序,然后每次都取最小并且和含有与MST里节点不同源的两端节点的边,因为要排序和源点统计所以适用于稀疏图。而prim算法按照节点来,每次都取离MST最近的一个点的边,耗时最多的就是维护最小距离的数组和找到最小距离的点,适用于稠密图。
这里kruskal用查并集,所以调用查并集的模板,然后建立边,以起点终点和距离为边的属性。然后对边进行从小到大的排序,然后每次取最小距离的边,并将边的起点和终点和边长度传入查并集进行 是否同源的判断,如果同源就返回-1,换另一个边继续,如果不同源就将两个边加入进MST即两个点与MST中的其他点同源,当MST中的点与总点数相同就返回查并集中对加入了MST的边的长度的统计。
代码:
1 class Djset{ 2 public: 3 vector<int> parent; 4 vector<int> rank; 5 vector<int> size; 6 vector<int> len; 7 int num; 8 Djset(int n): parent(n),rank(n),len(n,0),size(n,1),num(n){ 9 for(int i=0;i<n;i++) 10 parent[i]=i; 11 } 12 int find(int x){ 13 if(x!=parent[x]){ 14 parent[x]=find(parent[x]); 15 } 16 return parent[x]; 17 } 18 int merge(int x,int y,int length){ 19 int rootx=find(x); 20 int rooty=find(y); 21 if(rootx!=rooty){ 22 if(rank[rootx]<rank[rooty]){ //如果rootx的深度小于rooty 23 swap(rootx,rooty); //交换两个节点,此时rootx>rooty,交换的原因是后面的代码都能是默认rootx大的情况进行操作 24 } 25 parent[rooty]=rootx; //将rootx这个深度大于rooty的源节点作为rooty的源节点 26 if(rank[rootx]==rank[rooty]) rank[rootx]+=1; //虽然rootx源节点已经为rooty的源节点,但是rank没有改变,如果进入merge函数的时候两深度相等,那么rootx的源节点作为rooty源节点后,那么深度就会增加一个节点长度。 27 size[rootx] += size[rooty]; 28 len[rootx] += len[rooty] + length; 29 if(size[rootx]==num) return len[rootx]; 30 } 31 return -1; 32 } 33 }; 34 35 struct Edge { 36 int start; 37 int end; 38 int len; 39 }; 40 41 class Solution { 42 public: 43 int minCostConnectPoints(vector<vector<int>>& points) { 44 int res=0; 45 int n=points.size(); 46 Djset ds(n); 47 vector<Edge> Edges; 48 for(int i=0;i<n;++i){ 49 for(int j=i+1;j<n;++j){ 50 Edge edge={i,j,abs(points[i][0]-points[j][0])+abs(points[i][1]-points[j][1])}; 51 Edges.emplace_back(edge); 52 } 53 } 54 55 sort(Edges.begin(),Edges.end(),[](const auto& a,const auto& b){ 56 return a.len<b.len; 57 }); 58 59 for(auto& e: Edges){ 60 res=ds.merge(e.start,e.end,e.len); 61 if(res!=-1) return res; 62 } 63 return 0; 64 } 65 };