题目
点这里看题目。
分析
首先对数组进行排序。然后我们先给每组分配最大值和最小值。这样每对最大值和最小值在排序后的数组上可以表示为一条线段。而没有被选定的点,其贡献的方案数为覆盖它的线段数量。根据乘法原理,此时总方案数为每个未选定的点的贡献的积。
至于计算方案,我们不难想到用 DP 。可以发现当前点的贡献仅与会覆盖它的线段的数量有关。对于一个线段已确定的情况,会覆盖到点\(i\)的线段的数量为\(i\)的左侧的左端点数量减去右端点数量。我们可以理解为\(i\)的左侧未闭合线段的数量。
因此可以设计 DP :
\(f(i,j,k)\):前\(i\)个人,还有\(j\)个未分配完的组(即未闭合线段),不平衡总和为\(k\)的方案数。
转移考虑当前人的分配情况:
\[\begin{aligned} f(i,j,k)&=f(i-1,j,k)~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~(\text{单独一组})\\ &+f(i-1,j-1,k-a_i)~~~~~~~~~~~~~~~~~~~~(\text{新开一组})\\ &+j\times f(i-1,j,k)~~~~~~~~~~~~~~~~~~~~~~~~~~~(\text{加入一组,但是不作为最大值})\\ &+(j+1)\times f(i-1,j+1,k+a_i)~~~~~(\text{作为一组的最大值})\\ \end{aligned} \]
这样转移是\(O(n^2\sum a)\)的,因为 DP 途中\(k\)的取值并没有范围,只有在统计结果的时候才有范围。
这里可以使用差分。可以发现,如果选定最小值\(a_l\),最大值\(a_r\),那么这个序列的不平衡值可以等价地表示为:
\[a_r-a_l=\sum_{i=l+1}^r (a_i-a_{i-1}) \]
我们通过这种方法将不平衡值均摊到了线段的每一个元素上,因此我们可以在 DP 过程中 " 动态 " 地维护不平衡值的和。在 DP 过程中,\(f(i,j,k)\)的\(i\)总共贡献的不平衡值即为\(j\times (a_i-a_{i-1})\)。可以发现,由于\(a_i-a_{i-1}\ge 0\),因此我们需要保证 DP 中每一个不平衡值都在题设范围内。故可以重置状态:
\(f(i,j,k)\):前\(i\)个人,还有\(j\)个未分配完的组,当前不平衡值为\(k\)的方案数。
设\(d=a_i-a_{i-1}\),\(t=j\times d\),转移为:
\[\begin{aligned} f(i,j,k)&=f(i-1,j,k-t)\\ &+(j-1)\times f(i-1,j-1,k-t+d)\\ &+j\times f(i-1,j,k-t)\\ &+(j+1)\times f(i-1,j+1,k-t-d) \end{aligned} \]
时间复杂度为\(O(n^2k)\),答案为\(\sum_{i=0}^k f(n,0,i)\)。
代码
#include <cstdio>
#include <algorithm>
const int mod = 1e9 + 7;
#ifdef _DEBUG
const int MAXN = 15, MAXK = 105;
#else
const int MAXN = 205, MAXK = 1005;
#endif
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXN][MAXK];
int a[MAXN];
int N, K;
void upt( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int main()
{
read( N ), read( K );
for( int i = 1 ; i <= N ; i ++ ) read( a[i] );
std :: sort( a + 1, a + 1 + N );
f[0][0][0] = 1;
int t, dif;
for( int i = 1 ; i <= N ; i ++ )
for( int j = 0 ; j <= i ; j ++ )
for( int k = 0 ; k <= K ; k ++ )
{
dif = a[i] - a[i - 1], t = j * dif;
if( k >= t ) upt( f[i][j][k], f[i - 1][j][k - t] );
if( k >= t ) upt( f[i][j][k], 1ll * f[i - 1][j][k - t] * j % mod );
if( j && k >= ( t - dif ) ) upt( f[i][j][k], f[i - 1][j - 1][k - t + dif] );
if( j < i && k >= ( t + dif ) ) upt( f[i][j][k], 1ll * f[i - 1][j + 1][k - t - dif] * ( j + 1 ) % mod );
}
int ans = 0;
for( int i = 0 ; i <= K ; i ++ ) upt( ans, f[N][0][i] );
write( ans ), putchar( '\n' );
return 0;
}