洛谷 P6622 [省选联考 2020 A/B 卷] 信号传递

链接

P6622


分析

大毒瘤状压。。。

首先注意对于原序列上一个 \(x\rightarrow y\) 的贡献可以拆到 \(x\) 和 \(y\) 上,也就是说
\(\left\{\begin{matrix} g[x]+=k,g[y]+=k \ (x>y)\\ g[x]-=1,g[y]+=1 \ (x<y) \end{matrix}\right.\)
最后把每个数的序号乘上 \(g\) 再求和就是答案。

我们发现这样每个数的贡献只和在它前面的数有哪些有关,所以我们设 \(g[x][S]\) 表示在 \(x\) 之前的数集为 \(S\) 时 \(x\) 的上面的 \(g\)。
我们先把传递序列拆开,记 \(e[i][j]\) 表示从 \(i\) 直接走到 \(j\) 的边数。
于是我们可以轻松做到 \(O(m^2 2^m)\) 求出 \(g\)。

点击查看代码
for(int i=0;i<=S;i++)
	 for(int j=1;j<=m;j++){
		if(i&(1<<(j-1)))continue;
		for(int k=1;k<=m;k++){
			if(j==k)continue; 
			if((i>>(k-1))&1)g[j][i]+=e[k][j]+K*e[j][k];
			else g[j][i]+=-e[j][k]+K*e[k][j];
		}
	}

于是我们有了 \(g\),只需要确定数的序号就可以求答案。
受到 \(g\) 的启发,因为一个数的序号和 \(g\) 的值只和它前面有哪些数有关,不需要知道前面的顺序。
所以我们设 \(f[S]\) 表示已经确定了 \(S\) 数集的顺序。转移时发现 \(S\) 刚好和 \(g\) 的第二维相同,于是我们也可以轻松地做到 \(O(m 2^m)\) 求出 \(f\)。

点击查看代码
memset(f,127,sizeof(f));f[0]=0,num[0]=1;
for(int i=0;i<S;i++)
	for(int j=1;j<=m;j++){
		if(i&(1<<(j-1)))continue;
		f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g[j][i]),
		num[i|(1<<(j-1))]=num[i]+1; 
	}

现在我们有了一个较为成熟的 \(O(m^2 2^m)\) 的做法,此时你能得到 \(60\) 分的好成绩


时间复杂度优化

发现我们时间复杂度的瓶颈 \(g\) 的求解还有明显的可优化空间,因为 \(g\) 其实是可以由之前的 \(S\) 继承来的。只需从 \(S\) 中随便找一个数 \(j\) 去掉,那么根据 \(g\) 原本的求法我们可以得到 \(g[i][S]=g[i][S-(1<<j)]+(e[j][i]+K*e[i][j])-(-e[i][j]+K*e[j][i])\)。
这里随便找的 \(j\) 可以用 lowbit 去找。

于是我们有了时间复杂度 \(O(m 2^m)\) 的做法,得到了 \(80\) 分的好成绩

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-'0';c=getchar();}
	return p*f;
}
const int N=1e5+5;
const int T=(1<<23);
inline int lowbit(int x){return x&(-x);}
int n,m,K,S;
int e[25][25],a[N];
int g[25][T];
int f[T],num[T];
signed main(){
	n=in,m=in,K=in,S=(1<<m)-1;
	for(int i=1;i<=n;i++)a[i]=in;
	for(int i=1;i<n;i++)e[a[i]][a[i+1]]++;
	
	for(int i=1;i<=m;i++)
		for(int j=1;j<=m;j++)
			if(i!=j)g[i][0]+=-e[i][j]+K*e[j][i];
	for(int i=1,t=lowbit(i),k=0;i<=S;i++,t=lowbit(i),k=0){
		while((1<<k)!=t)k++;k++;
		for(int j=1;j<=m;j++)
			if(!(i&(1<<(j-1))))
				g[j][i]=g[j][i-t]+(1-K)*e[k][j]+(1+K)*e[j][k];
	}	
	memset(f,127,sizeof(f));f[0]=0,num[0]=1;
	for(int i=0;i<S;i++)
		for(int j=1;j<=m;j++){
			if(i&(1<<(j-1)))continue;
			f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g[j][i]),
			num[i|(1<<(j-1))]=num[i]+1; 
		}
	cout<<f[S];
	return 0;
}

空间复杂度优化

我们惊讶的发现空间居然爆了,显然我们空间复杂度的瓶颈也在 \(g\) 数组上,该怎么从 \(g\) 身上榨出一些空间出来呢.
我们注意到 \(g[x][S]\) 这样的 \(x\) 有一些 \(S\) 是无用的,就是 \(S\) 中包含 \(x\) 的情况。这些情况去掉后不会对求解产生影响。
那么我们可以尝试把含有 \(x\) 的 \(S\) 去掉,这样 \(g\) 的大小就从 \(23\times 2^{23}\) 的约 \(736MB\) 减小到 \(23\times 2^{22}\) 约 \(368MB\) 可以获得开O2 \(100\) 分的好成绩
别看说起来挺麻烦的,其实实现非常简单,只需要在用到在原本 \(80\) 分代码的基础上把出现 \(g\) 的地方微调一下第二维,把 \(x\) 的前后拼接起来。

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-'0';c=getchar();}
	return p*f;
}
const int N=1e5+5;
const int T=(1<<23);
inline int lowbit(int x){return x&(-x);}
inline int gety(int x,int y){return ((y>>x)<<(x-1))+y%(1<<(x-1));}
int n,m,K,S;
int e[25][25],a[N];
int g[25][T>>1];
int f[T],num[T];
#define g(x,y) g[x][gety(x,y)]
signed main(){
	n=in,m=in,K=in,S=(1<<m)-1;
	for(int i=1;i<=n;i++)a[i]=in;
	for(int i=1;i<n;i++)e[a[i]][a[i+1]]++;
	
	for(int i=1;i<=m;i++)
		for(int j=1;j<=m;j++)
			if(i!=j)g(i,0)+=-e[i][j]+K*e[j][i];
	for(int i=1,t=lowbit(i),k=0;i<=S;i++,t=lowbit(i),k=0){
		while((1<<k)!=t)k++;k++;
		for(int j=1;j<=m;j++)
			if(!(i&(1<<(j-1))))
				g(j,i)=g(j,i-t)+(1-K)*e[k][j]+(1+K)*e[j][k];
	}
	
	memset(f,127,sizeof(f));f[0]=0,num[0]=1;
	for(int i=0;i<S;i++)
		for(int j=1;j<=m;j++){
			if(i&(1<<(j-1)))continue;
			f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g(j,i)),
			num[i|(1<<(j-1))]=num[i]+1; 
		}
	cout<<f[S];
	return 0;
}

卡常

最后要想不开 O2 通过需要一定的卡常,这里把我卡完的代码贴出来,有一些常见的卡常,比如 fread,交换数组两维,register int,预处理2的幂,define,还有一个重要的是求解 \(g\) 和求 \(f\) 的两段可以合到一起。(其实仍然没卡过因为原题时限是2s)

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define in read()
#define lowbit(x) (x&(-x))
#define g(x,y) g[((((y)>>(x))<<(x-1))+(y)%(1<<(x-1)))][x]
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read(){
    char ch=nc();int sum=0;
    while(!(ch>='0'&&ch<='9'))ch=nc();
    while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
    return sum;
}
const int N=1e5+5;
const int T=(1<<23);
int n,m,K,S;
int e[25][25],a[N];
int g[T>>1][25];
int f[T],num[T];
int lg2[T];
signed main(){
	n=in,m=in,K=in,S=(1<<m)-1,num[0]=1;
	for(register int i=2;i<=S;i++)lg2[i]=lg2[i/2]+1;
	for(register int i=1;i<=n;i++)a[i]=in;
	for(register int i=1;i<n;i++)e[a[i]][a[i+1]]++;	
	
	for(register int i=1;i<=S;i++)f[i]=1000000000;
	
	for(register int i=1;i<=m;i++)
		for(int j=1;j<=m;j++)
			if(i!=j)g(i,0)+=-e[i][j]+K*e[j][i];
	for(int i=1;i<=m;i++)
		f[1<<(i-1)]=g(i,0);		
	for(register int i=1,t=lowbit(i),k=lg2[t]+1;i<=S;i++,t=lowbit(i),k=lg2[t]+1){
		num[i]=num[i-t]+1;
		for(int j=1;j<=m;j++)
			if(!(i&(1<<(j-1))))
				g(j,i)=g(j,i-t)+(1-K)*e[k][j]+(1+K)*e[j][k],
				f[i|(1<<(j-1))]=min(f[i|(1<<(j-1))],f[i]+num[i]*g(j,i));
	}

	cout<<f[S];
	return 0;
}
上一篇:MobileNetV1 V2 V3网络理解+pytorch源码


下一篇:adb input命令详解