“要求让最少个数的一种 'QwQ' 的个数最多的方案”,显然可知主体算法为二分答案,考虑 check()
怎样实现。
“保证对于每个 \(x\),最多有一个 \(a\) 使得 \(a\to x\) 成立”,故此转化关系可以抽象为树或基环树。尽管不保证连通,但我们可以建出超级源点,连接整个森林。对于树的情况,考虑从叶子向根推,计算根节点需要多少额外转化使其合法,若 \(dp_{root} \le m\) 则可行。这一过程可以用 dfs 回溯或拓扑排序完成。对于基环树,只需要缩环为点——tarjan 等求强连通分量即可。
代码复杂,可以使用 namespace
简化,降低调试难度。
下面是 AC 代码:
#include<cstdio>
#include<queue>
inline int min(const int& x,const int& y){return x<y?x:y;}
inline int rd()
{
int x=0,f=1;char c=getchar();
for(;c<'0'||c>'9';c=getchar()) f^=(c=='-');
for(;c>='0'&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
return f?x:-x;
}
const int N=1e6+10;
int n,m,fa[N],tmp[N],a[N],ind[N];
int tot,h[N],ver[N],nxt[N];
int scc,c[N],siz[N];//siz[]记录环中的点数,计算花费
long long val[N],f[N];//val[]计算缩点的点权和
std::queue<int> q;
namespace Tarjan
{
int dfs_clock,dfn[N],low[N];
int tot,h[N],ver[N],nxt[N];
bool vis[N];
int st[N],top;
inline void add(int u,int v)
{
nxt[++tot]=h[u];
ver[tot]=v;
h[u]=tot;
}
inline void tarjan(int u)
{
dfn[u]=low[u]=++dfs_clock;
st[++top]=u,vis[u]=true;
for(int i=h[u];i;i=nxt[i])
{
int v=ver[i];
if(!dfn[v])
{
tarjan(v);
low[v]=min(low[u],low[v]);
}
else if(vis[v]) low[u]=min(low[u],dfn[v]);
}
if(dfn[u]==low[u])
{
int x=st[top--];
c[x]=++scc,siz[scc]=1,val[scc]+=a[x],vis[x]=false;
while(u!=x)
{
x=st[top--];
c[x]=scc,++siz[scc],val[scc]+=a[x],vis[x]=false;
}
}
}
};
inline void add(int u,int v)
{
nxt[++tot]=h[u];
ver[tot]=v;
h[u]=tot;
}
inline bool check(long long k)
{
q.push(0);//注意加入超级源点
for(int i=1;i<=scc;++i)
{
f[i]=val[i];
ind[i]=tmp[i];
if(!ind[i]) q.push(i);
}
long long sum=0;
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=h[u];i;i=nxt[i])
{
int v=ver[i];
if(f[u]<k*siz[u])
{
f[v]-=k*siz[u]-f[u];
f[u]=k*siz[u];
}
--ind[v];
if(!ind[v]) q.push(v);
}
if(f[u]<k*siz[u]) sum+=k*siz[u]-f[u];
if(sum>m) return 0;
}
return sum<=m;
}
int main()
{
n=rd(),m=rd();
for(int i=1;i<=n;++i)
{
fa[i]=rd();
if(fa[i]!=-1&&fa[i]!=i) Tarjan::add(i,fa[i]);
}
for(int i=1;i<=n;++i) a[i]=rd();
for(int i=1;i<=n;++i)
if(!Tarjan::dfn[i]) Tarjan::tarjan(i);
for(int i=1;i<=n;++i)
{
if(fa[i]!=-1&&fa[i]!=i)
add(c[i],c[fa[i]]),++tmp[c[fa[i]]];
}
long long l=1e8,r=0;
for(int i=1;i<=n;++i)
{
l=min(l,a[i]);
r+=a[i];
}
r=(r+m)/n;
while(l<r)
{
long long mid=(l+r+1)>>1;
if(check(mid)) l=mid;
else r=mid-1;
}
printf("%lld\n",l);
}