[题目链接]
http://codeforces.com/contest/1456/problem/E
[题解]
考虑把所有的区间在值域 \([0 , 2 ^ k)\) 的线段树上拆分成 \(O(k)\) 个区间。
定义线段树的根深度为 \(0\) , 那么对于任意一个第 \(i\) 层的节点都对应着最高 \(i\) 位固定 , 后 \((k - i)\) 位任意的一个区间。
拆完区间后 , 我们考虑对于每个 \(a_{i}\) 定位到线段树上的节点计算贡献。
不妨设 \(f_{i , l , r , x , y}\) 表示 \(\geq i\) 的的位 , \(a_{l..r}\) 所有的数都只能选择 \(\geq 2^{i}\) 的子区间 , \(a_{l - 1}\) 与 \(a_{r + 1}\) 分别为 \(x , y\) 的最小代价。
若 \(a_{l..r}\) 子区间长度都大于 \(2 ^ {i}\)
有 \(f_{l,r,i,x,y}\leftarrow f_{l,r,i+1,x,y}+c_i[\lfloor\frac x{2^i}\rfloor\bmod 2\ne\lfloor\frac y{2^i}\rfloor\bmod 2]\)
否则 \(f_{l,r,i,x,y}\leftarrow f_{l,mid-1,i,x,z}+f_{mid+1,r,i,z,y}\)
时间复杂度 : \(O(N ^ 4)\)
[代码]
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MN = 55;
const LL INF = 1e18;
int n, k, w[MN][MN];
LL c[MN] , f[MN][MN][MN][4][4];
vector < int > fa[MN][MN];
vector < bool > is[MN][MN];
vector < LL > a[MN][MN];
inline void change(LL l , LL r , LL s , LL e , int i , int j , int f) {
if (e < l || s > r) return;
a[i][j].emplace_back(l); fa[i][j].emplace_back(f); ++w[i][j];
if (s <= l && r <= e) return (void) is[i][j].push_back(1);
else is[i][j].push_back(0);
LL mid = l + r >> 1; int tmp = w[i][j] - 1;
change(l , mid , s , e , i , j - 1 , tmp);
change(mid + 1 , r , s , e , i , j - 1 , tmp);
}
inline LL dp(int l , int r , int num , int x , int y) {
if (num == k && l > r) return 0;
if (f[l][r][num][x][y] != -1) return f[l][r][num][x][y];
LL res = num < k ? dp(l , r , num + 1 , fa[l - 1][num][x] , fa[r + 1][num][y]) : INF;
if (res < INF && l > 1 && r < n && ((a[l - 1][num][x] >> num) & 1) != ((a[r + 1][num][y] >> num) & 1))
res += c[num];
for (int i = l; i <= r; ++i)
for (int j = 0; j < w[i][num]; ++j) {
if (!is[i][num][j]) continue;
LL lc = dp(l , i - 1 , num , x , j) , rc = dp(i + 1 , r , num , j , y);
if (lc + rc < res) res = lc + rc;
}
return f[l][r][num][x][y] = res;
}
int main() {
LL l , r;
scanf("%d%d" , &n , &k);
for (int i = 0; i <= k; ++i)
fa[0][i].emplace_back(0) , fa[n + 1][i].emplace_back(0);
for (int i = 1; i <= n; ++i) {
LL l , r; scanf("%lld%lld" , &l , &r);
change(0 , (1LL << k) - 1 , l , r , i , k , 0);
}
for (int i = 0; i < k; i++) scanf("%lld" , &c[i]);
memset(f, -1, sizeof(f));
printf("%lld\n" , dp(1 , n , 0 , 0 , 0));
return 0;
}