题意
给出一张有向图，每个点有点权；忽略边的方向后，图分成若干个连通块。
在一个连通块内，可以从任意一点出发沿有向边行走，途经的点的点权计入收益（同一个点只计一次）。
最多选择 \(K+1\) 个不同的连通块，每个块内各走一次，求能获得的点权之和的最大值。
思路
用 Tarjan 求强连通分量并缩点（一个强连通分量内的点可以全部走到），得到 DAG 后按拓扑序 DP，求出以每个分量结尾的最大点权路径；同时用并查集把同一连通块内的分量合并，每个连通块只取其中的最大值，最后取最大的 \(K+1\) 个连通块求和。
代码
#include <queue>
#include <stack>
#include <cstdio>
#include <algorithm>
const int MAXN = 200001, MAXM = 2000001;  // vertex bound / edge bound
std::stack<int> s;  // Tarjan stack of vertices whose SCC is still open
std::queue<int> q;  // zero in-degree queue used by dp()
// n, m: vertex / edge counts; k: input parameter (k + 1 components may be taken);
// tot / tot1: edge counters of the original graph / condensation DAG;
// dfncnt: DFS timestamp; col: number of SCCs found; ans: final answer.
int n, m, k, tot, tot1, dfncnt, col, ans;
// Original graph: per-vertex Tarjan state and list heads are indexed by
// vertex (1..n), so MAXN suffices; only the edge arrays need MAXM.
int low[MAXN], dfn[MAXN], in_s[MAXN], scc[MAXN], ver[MAXM], next[MAXM], head[MAXN];
// Condensation DAG: list heads and in-degrees are indexed by SCC id
// (1..col, col <= n), so MAXN suffices; only the edge arrays need MAXM.
int ver1[MAXM], next1[MAXM], head1[MAXN], deg[MAXN];
// vw: vertex weights; vv: SCC weights; fa: union-find parent over SCC ids;
// used: marks union-find roots whose component has already been counted.
int vw[MAXN], vv[MAXN], fa[MAXN], used[MAXN];
// DP record for one SCC, sorted later by value.
struct node {
    int id;  // SCC number (1..col)
    int a;   // best path weight found so far that ends at this SCC
} f[MAXN];
// Append directed edge u -> v to the original graph's adjacency lists
// (head / next / ver); edge ids start at 1 so 0 terminates a list.
void add(int u, int v) {
    const int id = ++tot;
    ver[id] = v;
    next[id] = head[u];
    head[u] = id;
}
// Append directed edge u -> v (SCC ids) to the condensation DAG's adjacency
// lists (head1 / next1 / ver1); edge ids start at 1 so 0 terminates a list.
void add1(int u, int v) {
    const int id = ++tot1;
    ver1[id] = v;
    next1[id] = head1[u];
    head1[u] = id;
}
void tarjan(int u) {
low[u] = dfn[u] = ++dfncnt;
s.push(u);
in_s[u] = 1;
int v;
for (int i = head[u]; i; i = next[i])
if (!dfn[v = ver[i]]) {
tarjan(v);
low[u] = std::min(low[u], low[v]);
} else if (in_s[v])
low[u] = std::min(low[u], dfn[v]);
if (dfn[u] == low[u]) {
++col;
fa[col] = col;
int p;
while ((p = s.top()) != u) {
vv[col] += vw[p];
scc[p] = col;
in_s[p] = 0;
s.pop();
}
scc[p = s.top()] = col;
vv[col] += vw[p];
f[col].a = vv[col];
f[col].id = col;
in_s[p] = 0;
s.pop();
}
}
/*
 * Union-find root lookup over SCC ids with full path compression.
 *
 * Iterative on purpose: unions link roots without rank, so parent chains can
 * grow to length ~col before the first lookup. The recursive form recursed
 * once per link and could overflow the call stack.
 * Returns the root of x's set; every node on the walked path is re-parented
 * directly to that root.
 */
int find(int x) {
    int root = x;
    while (fa[root] != root)
        root = fa[root];
    // Second pass: point every node on the path straight at the root.
    while (x != root) {
        const int parent = fa[x];
        fa[x] = root;
        x = parent;
    }
    return root;
}
/*
 * Longest vertex-weighted path on the condensation DAG, in topological order
 * (queue of zero in-degree SCCs).
 * On entry f[x].a == vv[x]; on return f[x].a is the best total weight of a
 * path ending at SCC x. Decrements deg[] as edges are consumed and leaves q
 * empty.
 */
void dp() {
    for (int c = 1; c <= col; ++c) {
        if (deg[c] == 0)
            q.push(c);
    }
    while (!q.empty()) {
        const int from = q.front();
        q.pop();
        for (int e = head1[from]; e != 0; e = next1[e]) {
            const int to = ver1[e];
            const int via = vv[to] + f[from].a;
            if (via > f[to].a)
                f[to].a = via;
            if (--deg[to] == 0)
                q.push(to);
        }
    }
}
// std::sort comparator: orders SCC records by DP value, largest first.
int cmp(node x, node y) {
    return y.a < x.a;
}
int main() {
freopen("azeroth.in", "r", stdin);
freopen("azeroth.out", "w", stdout);
scanf("%d %d", &n, &m);
for (int i = 1, a, b; i <= m; i++) {
scanf("%d %d", &a, &b);
if (a != b)
add(a, b);
}
for (int i = 1; i <= n; i++)
scanf("%d", &vw[i]);
scanf("%d", &k);
for (int i = 1; i <= n; i++)
if (!dfn[i])
tarjan(i);
for (int i = 1; i <= n; i++)
for (int j = head[i]; j; j = next[j])
if (scc[i] != scc[ver[j]]) {
int f1 = find(scc[i]), f2 = find(scc[ver[j]]);
add1(scc[i], scc[ver[j]]);
deg[scc[ver[j]]]++;
fa[f1] = f2;
}
dp();
std::sort(f + 1, f + col + 1, cmp);
k++;
for (int i = 1, f1; i <= col && k; i++)
if (!used[f1 = find(f[i].id)]) {
used[f1] = 1;
ans += f[i].a;
k--;
}
printf("%d", ans);
}
坑
并查集写错，导致连通块合并出错
数组大小没开够：点数组与边数组都要按各自的上限开足