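根据常用套路,用一个在 $A$、$B$ 中都不出现的字符(代码中用 `#`)把 $A$ 和 $B$ 连接起来,求出后缀数组和 $height$ 数组,然后按 $height$ 分组:遇到 $height < K$ 就断开,这样每一组内任意两个后缀的 LCP 都不小于 $K$,而不同组的后缀之间 LCP 都小于 $K$。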
分完组考虑如何统计每一组的贡献。
对于每一组内每一对后缀 $(A_i, B_j)$(它们贡献 $\mathrm{LCP}(A_i, B_j) - K + 1$ 个长度不小于 $K$ 的公共子串),按排名先后拆成两部分:
- $rank(A_i) < rank(B_j)$
- $rank(A_i) > rank(B_j)$
然后就可以按排名从小到大枚举每一个后缀,统计排在它前面的另一类后缀($A_i$ 或 $B_j$)的贡献。
当前后缀与前面某个后缀的 LCP 等于两者之间 $height$ 的最小值,所以从当前后缀的前一个后缀往前走,这个贡献单调不增,用一个单调栈维护即可。
Code
/**
* poj
* Problem#3415
* Accepted
* Time: 1110ms
* Memory: 10232k
*/
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
// printf conversion for long long. Windows builds use the msvcrt-specific
// "%I64d" (older msvcrt runtimes do not accept "%lld"); everything else uses
// the standard "%lld". MSVC defines _WIN32 but not necessarily WIN32, so
// check both.
#if defined(_WIN32) || defined(WIN32)
#define Auto "%I64d"
#else
#define Auto "%lld"
#endif
using namespace std;
typedef bool boolean;
// Each shorthand macro must sit on its own line; fused onto one line, the
// first #define would swallow the second as part of its replacement text.
#define ll long long
#define pii pair<int, int>
#define fi first
#define sc second

// Capacity of every per-position array below; the text length passed to
// SuffixArray::set must not exceed it.
const int N = 200005;

// Radix-sort record: suffix start `id` with sort key (x, y), x primary.
typedef class Pair3 {
    public:
        int x, y, id;

        Pair3() { }
        Pair3(int x, int y, int id):x(x), y(y), id(id) { }
}Pair3;

/**
 * Suffix array by prefix doubling with a two-pass counting sort, plus the
 * height (LCP) array by Kasai's method.
 *
 * Usage: set(n, str); build(); get_height(); afterwards
 *   sa[r]  / (*this)[r] : start position of the suffix with rank r (0-based),
 *   rk[i]               : rank of the suffix starting at position i,
 *   hei[r] / (*this)(r) : LCP of the suffixes ranked r - 1 and r (hei[0] = 0).
 * Requires 0 <= n <= N. Characters are compared as unsigned bytes.
 */
typedef class SuffixArray {
    protected:
        Pair3 T1[N], T2[N];
        // Counting-sort buckets; keys range over [0, max(n, 256)].
        int cnt[N + 1];

    public:
        int n;
        char *str;
        int sa[N], rk[N], hei[N];

        // Binds the text (not copied, so it must stay alive through build()
        // and get_height()) and clears the result arrays.
        void set(int n, char* str) {
            this->n = n;
            this->str = str;
            std::memset(sa, 0, sizeof(sa));
            std::memset(rk, 0, sizeof(rk));
            std::memset(hei, 0, sizeof(hei));
        }

        // Sorts x[0..n) by (x, y), using y[] as scratch space: a counting sort
        // on the secondary key y, then a stable counting sort on the primary key x.
        void radix_sort(Pair3* x, Pair3* y) {
            // Largest key: a byte value + 1 (<= 256) or a rank + 1 (<= n).
            const int m = std::max(n, 256) + 1;

            std::memset(cnt, 0, sizeof(int) * m);
            for (int i = 0; i < n; i++)
                cnt[x[i].y]++;
            for (int i = 1; i < m; i++)
                cnt[i] += cnt[i - 1];
            for (int i = n - 1; i >= 0; i--)
                y[--cnt[x[i].y]] = x[i];

            std::memset(cnt, 0, sizeof(int) * m);
            for (int i = 0; i < n; i++)
                cnt[y[i].x]++;
            for (int i = 1; i < m; i++)
                cnt[i] += cnt[i - 1];
            for (int i = n - 1; i >= 0; i--)
                x[--cnt[y[i].x]] = y[i];
        }

        // Fills rk and sa. After the round with shift k, rk[i] is the 0-based
        // rank (equal prefixes share a rank) of the first 2k characters of suffix i.
        void build() {
            for (int i = 0; i < n; i++)
                rk[i] = static_cast<unsigned char>(str[i]);
            // The doubling loop never runs for n == 1, so normalise the rank here.
            if (n == 1)
                rk[0] = 0;
            for (int k = 1; k < n; k <<= 1) {
                // Second key is rank + 1 so that 0 is reserved for "past the end".
                // If 0 doubled as a real rank, a suffix of length k could tie
                // with a longer suffix sharing its first k characters, and the
                // final round could leave two suffixes with the same rank.
                for (int i = 0; i < n; i++)
                    T1[i] = Pair3(rk[i], (i + k < n) ? rk[i + k] + 1 : 0, i);
                radix_sort(T1, T2);

                int diff = 0;
                rk[T1[0].id] = 0;
                for (int i = 1; i < n; i++) {
                    if (T1[i].x != T1[i - 1].x || T1[i].y != T1[i - 1].y)
                        diff++;
                    rk[T1[i].id] = diff;
                }
                if (diff == n - 1)
                    break;  // every rank distinct: the order is final
            }
            for (int i = 0; i < n; i++)
                sa[rk[i]] = i;
        }

        // Kasai: visits suffixes in text order; the LCP with the predecessor
        // in rank order can shrink by at most one per step, so k is reused.
        void get_height() {
            int k = 0;
            for (int i = 0; i < n; i++) {
                if (rk[i] > 0) {
                    const int j = sa[rk[i] - 1];
                    while (i + k < n && j + k < n && str[i + k] == str[j + k])
                        k++;
                    hei[rk[i]] = k;
                }
                if (k > 0)
                    k--;
            }
        }

        // Start position of the suffix with rank p.
        const int& operator [] (int p) {
            return sa[p];
        }

        // LCP of the suffixes ranked p - 1 and p.
        const int& operator () (int p) {
            return hei[p];
        }
}SuffixArray;

// Minimum common-substring length read per test case; the reader stores it
// decremented by one.
int K;
int n, m;
char S[N];
SuffixArray sa;

/**
 * Reads one test case: K, then the strings A and B.
 *
 * Builds S = A + '#' + B, leaving n = |A| + 1 + |B| and m = |B|, and binds
 * S to `sa`. K is stored decremented by one, so a common prefix of length l
 * accounts for l - K substrings.
 *
 * Returns false on the terminating K = 0, or when the input ends or is
 * malformed. The previous version ignored scanf's result and could loop
 * forever on a missing terminator.
 *
 * NOTE(review): assumes '#' does not occur in A or B, and that
 * |A| + |B| + 2 <= N (the "%s" reads are unbounded) — confirm against the
 * problem limits before reuse.
 */
inline boolean init() {
    if (scanf("%d", &K) != 1 || K == 0)
        return false;
    K--;
    if (scanf("%s", S) != 1)
        return false;
    n = strlen(S);
    S[n] = '#';  // separator: overwrites A's terminator
    if (scanf("%s", S + n + 1) != 1)
        return false;
    m = strlen(S + n + 1);
    n += m + 1;
    sa.set(n, S);
    return true;
}

ll res = 0, sum = 0;
int tp = ;
pii st[N];
inline void solve(int L, int R) { // Calculate the s_i (i \in [L, R))
if (R - L < )
return ;
tp = sum = ;
for (int i = L, sg; i < R - ; i++) {
sg = (sa[i] < n - m - );
if (!sg)
res += sum;
while (tp && st[tp].fi >= sa(i + ))
sg += st[tp].sc, sum -= st[tp].sc * 1ll * (st[tp].fi - K), tp--;
sum += (sa(i + ) - K) * 1ll * sg;
st[++tp] = pii(sa(i + ), sg);
}
if (!(sa[R - ] < n - m - ))
res += sum; tp = sum = ;
for (int i = L, sg; i < R - ; i++) {
sg = !(sa[i] < n - m - );
if (!sg)
res += sum;
while (tp && st[tp].fi >= sa(i + ))
sg += st[tp].sc, sum -= st[tp].sc * 1ll * (st[tp].fi - K), tp--;
sum += (sa(i + ) - K) * 1ll * sg;
st[++tp] = pii(sa(i + ), sg);
}
if (sa[R - ] < n - m - )
res += sum;
} inline void solve() {
res = ;
sa.build();
sa.get_height(); int lst = ;
for (int i = ; i < n; i++)
if (sa(i) < K + )
solve(lst, i), lst = i;
solve(lst, n);
printf(Auto"\n", res);
} int main() {
while (init())
solve();
return ;
}