转化题意,给定n,m,a,b,序列元素为\(f_i=(a* i+b)\mod m+1\),求逆序对个数
首先不难发现序列分成了若干段等差数列,公差为 a
考虑到题目里两个限制
-
n<=m
-
m是质数,a在模m意义下有逆,\(\min (a,a^{-1})<=1000\)
不难计算等差数列的个数为\(O(a)\)
于是若a<=1000,已经有了一个\(O(a^2)\)的做法,枚举两段等差数列,\(O(1)\)计算逆序对个数
计算时可以先将两段对齐然后计算
若\(a^{-1}<=1000\),\(a* i+b+1=f_i\),移项得\(i=(f_i-1)* a^{-1}-b* a^{-1}\)
以 \(f_i-1\)为下标,此时是\(O(a^{-1})\)段公差为\(a^{-1}\)的的等差数列,可以\(O((a^{-1})^2)\)解决
需要注意的是第二种情况并非所有元素都是合法的,注意去掉\(i>n\)和\(i=0\)的情况
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+11;
struct dc_{
int a,len;
}xl[N];
int num;
int n,mod,a,b;
inline int read()
{
int s=0;
char ch=getchar();
while(ch>'9'||ch<'0') ch=getchar();
while(ch>='0'&&ch<='9')
{
s=(s<<1)+(s<<3)+(ch^48);
ch=getchar();
}
return s;
}
inline int min_(int x,int y){return x>y?y:x;}
inline long long js(int x,int k){return 1ll*(x+(x-k+1))*k/2;}
int fm(int x,int y)
{
int ans=1;
while(y)
{
if(y&1) ans=ans*x%mod;
y>>=1;
x=x*x%mod;
}
return ans;
}
signed main()
{
n=read();
mod=read();
a=read();
b=read();
int ans=0;
if(a>1000)
{
a=fm(a,mod-2),b=(mod-b*a%mod)%mod;
num=0;
int ed=0;
int lastlen=0;
for(;;)
{
if(lastlen>=n) break;
if(b>n)
{
ed=(mod-b)/a+1;
(b+=(ed*a))%=mod;
continue;
}
++num;
xl[num].a=b;
xl[num].len=(n-b)/a+1;
if(lastlen+xl[num].len>=n) {xl[num].len=n-lastlen;break;}
lastlen+=xl[num].len;
(b+=xl[num].len*a)%=mod;
}
}
else
{
b+=a;
num=0;
int lastlen=0;
for(;;)
{
++num;
xl[num].a=b%mod;
xl[num].len=(mod-xl[num].a-1)/a+1;
if(lastlen+xl[num].len>=n) {xl[num].len=n-lastlen;break;}
b=(b+a*((mod-xl[num].a-1)/a+1))%mod;
lastlen+=xl[num].len;
}
}
for(int a1,i=1;i<=num;++i)
{
a1=(xl[i].len-1)*a+xl[i].a;
for(int a2,qs,len,j=i+1;j<=num;++j)
{
a2=(xl[j].len-1)*a+xl[j].a;
if(a1<=xl[j].a) continue;
else if(a2<xl[i].a){ans+=1ll*xl[i].len*xl[j].len;continue;}
else if(a2==xl[i].a){ans+=1ll*xl[i].len*xl[j].len-1;continue;}
if(xl[i].a>xl[j].a)
{
qs=ceil((double)(xl[i].a-xl[j].a)/a);
len=min_(xl[i].len,xl[j].len-qs+1);
ans+=1ll*(qs-1)*xl[i].len+js(xl[i].len,len);
}
else
{
qs=(xl[j].a-xl[i].a)/a+1;
len=min_(xl[i].len-qs,xl[j].len);
ans+=js(xl[i].len-qs,len);
}
}
}
cout<<ans<<endl;
return 0;
}