豪华升级版同余类最短路……
主要写几个小trick:
\(1.O(nm)\)实现同余类最短路:
设某一条边长度为\(x\),那么我们选择一个点,在同余类上不断跳\(x\),可以形成一个环。
显然只有在同一个环上的两点之间才可能通过\(x\)进行转移。我们选择环上答案最小的点,它一定不会在当次更新时被更新答案,所以直接从这个点开始依次遍历环上的所有点,每一个点尝试从前面的一个点更新答案。
\(2.\)将\(\mod n\)的同余类最短路变为\(\mod d\)的同余类最短路:
令新的同余类最短路为\(g_x\),原同余类最短路为\(f_x\),那么首先令\(g_{f_i \mod d} \leftarrow f_i\),但是可能会有一些\(g\)没有被正确更新。在\(f_x\)中实际上还有默认的长度为\(n\)的边,那么在\(g\)中用长度为\(n\)的边在\(g\)上更新一次同余类最短路就可以得到正确的答案了。
\(3.\)更新\(border\)长度为等差数列的一段数的操作过程:
设这一个等差数列的首项为\(x\),公差为\(y\),有\(t+1\)项,先将原最短路变为\(\mod x\)的同余类最短路,那么对于每一个环上的点,可以从前面\(t\)个点进行转移,代价为距离\(\times y + x\),本质是一个多重背包。与此同时类似trick1地,转移一定不会跨越一个环上的最小值点,所以可以破环成链变为多重背包问题,使用单调队列优化转移。
PS:UOJ EX5好毒瘤啊……
Update on 2019.12.28:重写了一遍终于过了UOJ的Ex test,特来还愿。下面的代码更新为可以通过UOJ的代码。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int _ = 5e5 + 7;
char str[_]; ll mn[_] , tmp[_] , W; int T , L , curl , nxt[_];
template < typename T >
void chkmin(T &a , T b){a = a < b ? a : b;}
int upd(int x){return x + (x >> 31 & curl);}
int num[_];
void ext(int len){
int stp = len % curl; if(!stp) return;
memset(num , 0 , sizeof(int) * (curl + 1));
for(int i = 0 ; !num[i] ; ++i){
int x = i , id = i;
do{id = mn[id] > mn[x] ? x : id; x = upd(x + stp - curl); num[x] = 1;}while(x != i);
x = id; do{int p = upd(x + stp - curl); chkmin(mn[p] , mn[x] + len); x = p;}while(x != id);
}
}
ll val[_]; int que[_] , hd , tl;
void ext1(int len , int tms){
int stp = len % curl; if(!stp) return;
memset(num , 0 , sizeof(int) * (curl + 1));
for(int i = 0 ; !num[i] ; ++i){
int x = i , id = i , cnt = 0;
do{id = mn[id] > mn[x] ? x : id; x = upd(x + stp - curl);}while(x != i);
x = id; hd = tl = 0;
do{
num[x] = ++cnt; if(hd != tl && cnt - num[que[hd]] > tms) ++hd;
if(hd != tl) chkmin(mn[x] , val[que[hd]] + 1ll * stp * cnt + curl);
val[x] = mn[x] - 1ll * stp * cnt;
while(hd != tl && val[que[tl - 1]] >= val[x]) --tl;
que[tl++] = x; x = upd(x + stp - curl);
}while(x != id);
}
}
void change(int nl){
memset(tmp , 0x3f , sizeof(ll) * max(tl , nl));
for(int i = 0 ; i < curl ; ++i) chkmin(tmp[mn[i] % nl] , mn[i]);
int tl = curl; curl = nl; memcpy(mn , tmp , sizeof(ll) * max(tl , nl)); ext(tl);
}
int main(){
nxt[0] = -1;
for(scanf("%d" , &T) ; T ; --T){
scanf("%d %lld %s" , &L , &W , str + 1); W -= L;
memset(mn , 0x3f , sizeof(mn)); mn[0] = 0;
for(int i = 1 ; i <= L ; ++i){
int t = nxt[i - 1];
while(~t && str[t + 1] != str[i]) t = nxt[t];
nxt[i] = t + 1;
}
vector < int > border; int t = curl = L , pos = 0;
while(nxt[t]) border.push_back(L - (t = nxt[t]));
while(pos < border.size())
if(pos + 2 >= border.size()) ext(border[pos++]);
else if(border[pos + 2] - border[pos + 1] == border[pos + 1] - border[pos]){
int r = pos + 2 , len = border[pos + 1] - border[pos];
while(r + 1 < border.size() && border[r + 1] - border[r] == len) ++r;
change(border[pos]); ext1(len , r - pos); pos = r + 1;
}
else ext(border[pos++]);
ll sum = 0;
for(int i = 0 ; i < curl ; ++i)
if(W >= mn[i]) sum += (W - mn[i]) / curl + 1;
cout << sum << endl;
}
return 0;
}