arc099_f Eating Symbols Hard
https://atcoder.jp/contests/arc099/tasks/arc099_d
Tutorial
https://img.atcoder.jp/arc099/editorial.pdf
考虑用哈希来判断序列的相等.设\(A\)的哈希值为\(f(A)=\sum A_ibase^i\),设\(g(S)\)表示\(S\)生成的序列\(A\)的\(f(A)\)
那么+-<>对哈希值的影响为
- \(g(+S)=g(S)+1\)
- \(g(-S)=g(S)-1\)
- \(g(>S)=g(S)base\)
- \(g(<S)=g(S)base^{-1}\)
发现第\(i\)个字符可以表示为一次函数 \(h_i\) 的形式.
设\(c=g(S)\),我们要求的就是\((i,j)\)满足
\[h_i \circ h_{i+1} \cdots \circ h_j(0) = c \\ h_n^{-1} \circ h_{n-1}^{-1} \cdots \circ h_i^{-1} \circ h_{i+1} \cdots \circ h_j(0) = h_n^{-1} \circ h_{n-1}^{-1} \cdots \circ h_i^{-1}(c) \\ h_n^{-1} \circ h_{n-1}^{-1} \cdots \circ h_{j+1}^{-1}(0) = h_n^{-1} \circ h_{n-1}^{-1} \cdots \circ h_i^{-1}(c) \]
那么用map即可求解.
Code
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <map>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define inver(a,mod) power(a,mod-2,mod)
#define idx(a,b) ((ll)(a)*mod[1]+(b))
using namespace std;
template<class T> void rd(T &x) {
x=0; int f=1,ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=getchar();}
x*=f;
}
typedef long long ll;
const int mod[2]={998244353,1004535809};
const int maxn=250000+50;
int bs[2],rb[2];
int n; char s[maxn];
int c[2];
map<ll,int> cnt;
struct func {
int k,b;
func(int k=1,int b=0):k(k),b(b){}
inline int f(int x,int mod) {
return ((ll)k*x+b)%mod;
}
} a[maxn][2];
inline func mer(func a,func b,int mod) {
return func((ll)a.k*b.k%mod,((ll)a.b*b.k+b.b)%mod);
}
inline int add(int x,int mod) {return x>=mod?x-mod:x;}
inline int sub(int x,int mod) {return x<0?x+mod:x;}
ll power(ll x,ll y,int mod) {
ll re=1;
while(y) {
if(y&1) re=re*x%mod;
x=x*x%mod;
y>>=1;
}
return re;
}
void init() {
srand((unsigned long long)(new char));
bs[0]=rand(),rb[0]=inver(bs[0],mod[0]);
bs[1]=rand(),rb[1]=inver(bs[1],mod[1]);
}
int main() {
init();
rd(n);
scanf("%s",s+1);
for(int i=n;i>=1;--i) for(int k=0;k<2;++k) {
func f; switch(s[i]) {
case '+': f=func(1,mod[k]-1),c[k]=add(c[k]+1,mod[k]); break;
case '-': f=func(1,1),c[k]=sub(c[k]-1,mod[k]); break;
case '>': f=func(rb[k],0),c[k]=(ll)c[k]*bs[k]%mod[k]; break;
case '<': f=func(bs[k],0),c[k]=(ll)c[k]*rb[k]%mod[k]; break;
}
a[i][k]=mer(f,a[i+1][k],mod[k]);
}
ll an=0;
++cnt[0];
for(int i=n;i>=1;--i) {
an+=cnt[idx(a[i][0].f(c[0],mod[0]),a[i][1].f(c[1],mod[1]))];
++cnt[idx(a[i][0].f(0,mod[0]),a[i][1].f(0,mod[1]))];
}
printf("%lld\n",an);
return 0;
}