记操作序列为$S$,令$h(S)\equiv \sum_{i}a_{i}x^{i}(mod\ p)$(其中$a_{i}$为操作后的结果)
(以下我们将$S$看作字符串,相邻即拼接操作)
对于操作,有$h(1S)=xh(S)$,$h(3S)=h(S)+1$(另外两种操作类似),这可以看作一个函数,即定义函数$g_{S_{1}}(h(S_{2}))=h(S_{1}S_{2})$
令$s[i,j]$表示操作序列的区间$[i,j]$的子串,则有$g_{s[1,i)}h(s[i,j])=h([1,j])$,同时区间$[i,j]$合法当且仅当$h(s[i,j])=h[s(1,n)]$,即等价于$pre_{j}=g_{s[1,i)}(pre_{n})$(其中$pre_{j}=h(s[1,j])$)
发现右边仅与$i$有关,倒序枚举$i$求出该值,然后在$[i,n]$中找到相同的$pre_{j}$数量,可以用map维护,时间复杂度为$o(n\log_{2}n)$
(后面的值计算可能比较麻烦,可以将$s[1,i)$中的位移和权值拆开来计算)
考虑哈希冲突的概率,假设$x$为变量,那么$h(S)$就是一个关于$x$的函数,且其次数至多为$2n$(算上负幂次),因此$h(S_{1})=h(S_{2})$也就是一个$2n$次的同余方程,由于$p$为大素数,解数量基本为$2n$个
假设选择了$k$个$x$,而$2n$个解就会使得$(2n)^{k}$组$x$不合法,共要判断$o(n^{2})$对哈希值,即会使得$o(n^{2}(2n)^{k})$(同样忽略此处相同)组$x$不合法
对于总共$p^{k}$种,不能选$n^{2}(2n)^{k}$种,即不合法概率为$n^{2}(\frac{2n}{p})^{k}$,取$k=6$,$p$~$10^{9}$可以基本避免冲突
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 250005 4 #define mod 1000000007 5 struct ji{ 6 int a[6]; 7 bool operator < (const ji &k)const{ 8 for(int i=0;i<6;i++) 9 if (a[i]!=k.a[i])return a[i]<k.a[i]; 10 return 0; 11 } 12 }o,a[N]; 13 map<ji,int>mat; 14 int n,p[N],mi[6][N],mi_inv[6][N],x[6]={998244311,998244341,998244353,998244389,998244391,998244397}; 15 long long ans; 16 char s[N]; 17 int ksm(int n,int m){ 18 if (!m)return 1; 19 int s=ksm(n,m>>1); 20 s=1LL*s*s%mod; 21 if (m&1)s=1LL*s*n%mod; 22 return s; 23 } 24 int calc(int p,int k){ 25 if (k>=0)return mi[p][k]; 26 return mi_inv[p][-k]; 27 } 28 int main(){ 29 scanf("%d%s",&n,s); 30 for(int i=0;i<5;i++){ 31 mi[i][0]=1; 32 for(int j=1;j<=n;j++)mi[i][j]=1LL*mi[i][j-1]*x[i]%mod; 33 mi_inv[i][1]=ksm(x[i],mod-2); 34 for(int j=2;j<=n;j++)mi_inv[i][j]=1LL*mi_inv[i][j-1]*mi_inv[i][1]%mod; 35 } 36 p[0]=0; 37 for(int i=1;i<=n;i++){ 38 p[i]=p[i-1]; 39 a[i]=a[i-1]; 40 if (s[i-1]=='<')p[i]--; 41 if (s[i-1]=='>')p[i]++; 42 if (s[i-1]=='+') 43 for(int j=0;j<6;j++)a[i].a[j]=(a[i].a[j]+calc(j,p[i]))%mod; 44 if (s[i-1]=='-') 45 for(int j=0;j<6;j++)a[i].a[j]=(a[i].a[j]+mod-calc(j,p[i]))%mod; 46 } 47 for(int i=n;i;i--){ 48 for(int j=0;j<6;j++)o.a[j]=(1LL*a[n].a[j]*calc(j,p[i-1])+a[i-1].a[j])%mod; 49 mat[a[i]]++; 50 ans+=mat[o]; 51 } 52 printf("%lld",ans); 53 }View Code