[八省联考2018]制胡窜 (SAM+大讨论)

正着做着实不太好做,正难则反,考虑反着做。

把i,j看成在切割字符串,我们统计有多少对(i,j)会切割所有与\(s_{l,r}\)相同的串。对于在后缀自动机上表示\(s_{l,r}\)的节点x,x的parent子树内的endpos节点集合,就是和\(s_{l,r}\)相等的串的最后一个字符的出现位置。我们相当于在s串里得到了若干个线段,每个线段表示的子串都和\(s_{l,r}\)相等,然后用两刀把这些串都割了。我们分最左边的串和最右边的串是否存在交集进行讨论。

如果存在交集,线段数量是m

1.第一刀切串[1,i],第二刀切[i+1,m],方案数\((r_{i+1}-r_{i})(r_{i+1}-l_{m})\)

2.第一刀切[1,m],第二刀在第一刀右面随便切,是一个等差数列

3.第一刀切在第一个串左边,第二刀切在交集,一个乘法原理

如果不存在交集

可行的位置收到了限制,我们要求第一刀必须切第一个串,第二刀必须切第m个串,我们讨论出第一刀可行的线段编号区间[L,R],再统计方案数。

总之两种情况都需要维护\(\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m})\)这个式子,把它拆开。

\[\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m}) \\=\sum_{i=L}^{R}(\ (r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{i+1}-r_{i})\ ) \\=\sum_{i=L}^{R}(r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{R}-r_{L}) \]

常用套路,用线段树合并维护endpos集合,和式第二项维护相邻两项的乘积,对应pushup时左区间max和右区间min,我们需要维护一段区间内最大/最小值,再维护和式即可

#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long 
using namespace std;

template <typename _T> void read(_T &ret)
{
    ret=0; _T fh=1; char c=getchar();
    while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
    while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
    ret=ret*fh;
}
const int N1=1e5+5, S1=N1*2, M1=S1*70, inf=0x3f3f3f3f;
struct EDGE{
int to[S1],nxt[S1],head[S1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u], head[u]=cte; }
}e;
struct node{ ll sum; int mi,ma; 
friend node operator + (const node &s1,const node &s2)
{ return (node){s1.sum+s2.sum-((s2.mi!=inf)?1ll*s1.ma*s2.mi:0ll), min(s1.mi,s2.mi) , max(s1.ma,s2.ma)}; }
};

int n,Q;
char str[N1];
int idx(char c){ return c-'0'; }

struct SEG{
int mi[M1],ma[M1],ls[M1],rs[M1],root[S1],tot; ll sum[M1];
void init(){ mi[0]=inf; }
void pushup(int rt)
{
    mi[rt]=min(mi[ls[rt]],mi[rs[rt]]);
    ma[rt]=max(ma[ls[rt]],ma[rs[rt]]);
    sum[rt]=sum[ls[rt]]+sum[rs[rt]];
    if(mi[rs[rt]]!=inf) sum[rt]-=1ll*ma[ls[rt]]*mi[rs[rt]];
}
void ins(int x,int l,int r,int &rt)
{
    if(!rt) rt=++tot; 
    if(l==r){ mi[rt]=ma[rt]=l; sum[rt]=1ll*l*l; return; }
    int mid=(l+r)>>1;
    if(x<=mid) ins(x,l,mid,ls[rt]);
    else ins(x,mid+1,r,rs[rt]);
    pushup(rt);
}
//位置互不相同 在线段树叶节点一定会return 无需额外特判
int merge(int r1,int r2)
{
    if(!r1||!r2) return r1+r2; 
    int rt=++tot; 
    ls[rt]=merge(ls[r1],ls[r2]);
    rs[rt]=merge(rs[r1],rs[r2]);
    pushup(rt);
    return rt;
}
int lower(int x,int l,int r,int rt)
{
    if(l==r){
        if(mi[rt]<=x) return mi[rt];
        else return -1;
    }
    int mid=(l+r)>>1;
    if(mi[rs[rt]]<=x) return lower(x,mid+1,r,rs[rt]);
    else return lower(x,l,mid,ls[rt]);
}
int upper(int x,int l,int r,int rt)
{
    if(l==r){
        if(ma[rt]>=x) return ma[rt];
        else return -1;
    }
    int mid=(l+r)>>1;
    if(ma[ls[rt]]>=x) return upper(x,l,mid,ls[rt]);
    else return upper(x,mid+1,r,rs[rt]);
}
node query(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R){
        return (node){sum[rt],mi[rt],ma[rt]};
    }
    int mid=(l+r)>>1; node ans=(node){0ll,inf,0};
    if(L<=mid) ans=(ans+query(L,R,l,mid,ls[rt]));
    if(R>mid) ans=(ans+query(L,R,mid+1,r,rs[rt]));
    return ans;
}
}s;

int trs[S1][10],pre[S1],dep[S1],id[S1],tot,la;
void init(){ tot=la=1; }
void insert(int c,int i)
{
    int p=la,np=++tot,q,nq; la=np;
    dep[np]=dep[p]+1; 
    s.ins(i,1,n,s.root[np]); id[i]=np;
    for(;p&&!trs[p][c];p=pre[p]) trs[p][c]=np;
    if(!p){ pre[np]=1; return; }
    q=trs[p][c]; 
    if(dep[q]==dep[p]+1) pre[np]=q;
    else{
        pre[nq=++tot]=pre[q];
        pre[q]=pre[np]=nq;
        dep[nq]=dep[p]+1;
        memcpy(trs[nq],trs[q],sizeof(trs[nq]));
        for(;p&&trs[p][c]==q;p=pre[p]) trs[p][c]=nq;
    }
}
int ff[S1][19];
void dfs(int x)
{ 
    for(int j=2;j<=18;j++) ff[x][j]=ff[ ff[x][j-1] ][j-1];
    for(int j=e.head[x];j;j=e.nxt[j]){
        int v=e.to[j];
        dfs(v);
        s.root[x]=s.merge(s.root[x],s.root[v]);
    }
}
void build()
{
    for(int i=2;i<=tot;i++) e.ae(pre[i],i), ff[i][0]=i, ff[i][1]=pre[i];
    dfs(1);
}

int main()
{
    // freopen("1.in","r",stdin);
    read(n); read(Q);
    scanf("%s",str+1);
    init(); s.init();
    for(int i=1;i<=n;i++) insert(idx(str[i]),i);
    build();
    int l,r,x,len;
    for(int q=1;q<=Q;q++){
        read(l); read(r); len=r-l+1;
        x=id[r];
        // for(;dep[pre[x]]<=len;x=pre[x])
        for(int j=18;j>=0;j--)
            if(dep[ff[x][j]]>=len) x=ff[x][j];
        ll ans=1ll*(n-1)*(n-2)/2,tmp=0;
        int r1=s.mi[s.root[x]], rm=s.ma[s.root[x]], lm=rm-len+1, l1=r1-len+1;
        if(r1>lm){ //s1与sm有交
            tmp+=s.sum[s.root[x]]-1ll*r1*r1-1ll*lm*(rm-r1);
            tmp+=max(0ll,1ll*(2*n-lm-1-r1)*(r1-lm)/2);
            tmp+=max(0ll,1ll*(l1-1)*(r1-lm));
        }else{
            int L=s.lower(lm,1,n,s.root[x]);
            int R=s.lower(r1+len-2,1,n,s.root[x]), lR=R-len+1;
            int nxt=s.upper(R+1,1,n,s.root[x]);
            if(L!=-1 && r!=-1 && L<=R){
                node k=s.query(L,R,1,n,s.root[x]);
                tmp+=k.sum-1ll*L*L-1ll*lm*(R-L);
                tmp+=1ll*(r1-lR)*(nxt-lm);
            }
        }
        ans-=tmp;
        printf("%lld\n",ans);
    }
    // printf("%llu\n",(sizeof(s)+sizeof(ff)+sizeof(e)+sizeof(trs))/1024/1024);
    return 0;
}
 
上一篇:字符串


下一篇:21.7.8 t1