https://codeforces.com/contest/1188/problem/C
思路:
那部分差分没懂干净阿....郁闷
参考:
https://www.luogu.com.cn/problem/solution/CF1188C
由于题目美丽值和数的顺序无关,可以先对a数组从小到大排序,
那么子序列的美丽度就是对所有相邻数的差值取min,
因为a[i]<=1e5,那么美丽值最大不超过1e5,
假设:
美丽值=1的子序列数量为s[1]
美丽值=2的子序列数量为s[2]
...
美丽值为1的子序列对答案的贡献为s[1]*1
美丽值为2的子序列对答案的贡献为s[2]*2
美丽值为3的子序列对答案的贡献为s[3]*3
设s[]的后缀和为ss[]
那么s[1]*1+s[2]*2+s[3]*3...=ss[1]+ss[2]+ss[3]...
ss[]是s[]的后缀和,同时ss[i]也表示满足美丽值>=i的方案数
尝试计算ss[]:
对于一个美丽值x,满足子序列相邻元素差值都>=x的方案数就是ss[x]
枚举x,对于每一个x单独计算:
令d[i][j]为前i个数,选j个且以i结尾,差值>=x的方案数
那么转移方程为d[i][j]=sum(d[k][j-1]),其中a[i]-a[k]>=x
因为满足a[i]-a[k]>=x的k一定是连续的,且k是递增的
所以可以用一个指针维护这一段的右边端点(左端点为1),同时维护前缀和
这样sum(d[k][j-1])转移就是O(1)的了,dp过程复杂度为O(n*k)
本次计算出的d[n][k]就是ss[x]
x范围最大a[n],那么枚举x+dp的总复杂度为O(n*k*a[n]),还需要优化
一个dp范围优化:
长度为k的合法x子序列,需要满足k-1对相邻数都>=x,
那么在这个子序列中,第k个数和第1个数的差值>=(k-1)*x
考虑x取值的最大范围:
第k个数和第1个数的最大差值显然为a[n]-a[1],
那么a[n]-a[1]>=(k-1)*x,移项一下变为x<=(a[n]-a[1])/(k-1),
(a[n]-a[1])/(k-1)粗略当成a[n]/k
那么总复杂度就不是O(n*k*a[n])了,而是O(n*k*(a[n]/k))=O(n*a[n])
ps:
我的代码是对于d[i][j]的每一个j开了一个指针pos[j],
维护满足a[i]-a[k]>=x的k的最右位置,
同时还开了sum[][]维护d[][]的前缀和,
dp的时候,第一层枚举i,第二层枚举j.
但是看了其他人的代码,似乎是第一层枚举j,第二层枚举i
这样的话对于每一层j,维护sum(d[k][j-1])只用一个变量指针(而不是我的数组),
这样似乎会更快一点?(别人500+ms,我1900+ms)
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
const int maxn=1e3+100;
typedef int LL;
const LL mod=998244353;
inline LL read(){LL x=0,f=1;char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;}
LL a[1010];
long long dp[maxn][maxn],sum[maxn][maxn];
int main(void){
LL n,k;scanf("%d%d",&n,&k);
for(LL i=1;i<=n;i++) scanf("%d",&a[i]);
sort(a+1,a+1+n);
long long ans=0;
a[0]=-0x3f3f3f3f;
for(LL v=1;v*(k-1)<=a[n];v++){///枚举美丽值
LL now=0;LL res=0;
dp[0][0]=1;sum[0][0]=1;
for(LL i=1;i<=n;i++){
while(a[now]<=a[i]-v) now++;
for(LL j=0;j<=k;j++){
if(j) dp[i][j]=sum[now-1][j-1];
sum[i][j]=(sum[i-1][j]+dp[i][j])%mod;
}
res=(res+dp[i][k])%mod;
}
ans=(ans+res)%mod;
}
printf("%lld\n",ans);
return 0;
}