2021年度训练联盟热身训练赛第三场 L——Traveling Merchant(2*7线段树+查询时的区间合并)

题目链接

学长说也可以用倍增、或分块来做。我还是用思路最简单的线段树搞搞吧


题目大意:
1、n个城市,一个商人在一个城市买糖果,在另一个城市卖。(一个点买入,另一个点卖出)
2、商人只能一步一天的走,起点为x,终点为y,(可能x大于y),每个城市只能去一次,且只能顺序走,步长为1。
3、每个城市的糖果价格初值为v,变化值为d,周一和周天为 v,周二和周六为 v + d,周三和周五为 v + 2 * d,周四为 v + 3 * d
4、出发当天(即起点)固定为周一。
5、求最大能赚多少。


题目转化:

1、线段树求区间内有方向的差值最大
2、2个7线段树(分正反走)


注意点:

1、维护区间最大,最小,区间左到右的最大差值 或 右到左的最大差值。
2、得确定当前端点为周几。
3、查询时有区间合并。


下面的代码是我一开始比赛时一直错的代码,只过了20。没考虑到区间合并

#define ls		i<<1
#define rs		i<<1|1
#define mid     (l+r>>1)
#define lson	ls,l,mid
#define rson	rs,mid+1,r
const int maxn=2e5+7;
ll n,m,a[maxn][10],p;
struct node{	ll mx,mi,sum;	};
node t1[maxn<<2][10],t2[maxn<<2][10];
void pushup1(int i){
	t1[i][p].sum=t1[rs][p].mx-t1[ls][p].mi;
	t1[i][p].sum=max(t1[i][p].sum,max(t1[ls][p].sum,t1[rs][p].sum));
	t1[i][p].mi=min(t1[ls][p].mi,t1[rs][p].mi);
	t1[i][p].mx=max(t1[ls][p].mx,t1[rs][p].mx);
}
void pushup2(int i){
	t2[i][p].sum=t2[ls][p].mx-t2[rs][p].mi;
	t2[i][p].sum=max(t2[i][p].sum,max(t2[ls][p].sum,t2[rs][p].sum));
	t2[i][p].mi=min(t2[ls][p].mi,t2[rs][p].mi);
	t2[i][p].mx=max(t2[ls][p].mx,t2[rs][p].mx);
}
void build1(int i,int l,int r){
	if(l==r){
		int k=((l-1)%7+p)%7;
		if(!k)	k=7;
		t1[i][p].mx=t1[i][p].mi=a[l][k];
//		cout<<l<<' '<<k<<' '<<a[l][k]<<endl;
		return;
	}
	build1(lson);	build1(rson);
	pushup1(i);
//	cout<<l<<'~'<<r<<' '<<t1[i][p].sum<<endl;
}
void build2(int i,int l,int r){
	if(l==r){
		int k=((n-l)%7+p)%7;
		if(!k)	k=7;
		t2[i][p].mx=t2[i][p].mi=a[l][k];
//		cout<<l<<' '<<k<<' '<<a[l][k]<<endl;
		return;
	}
	build2(rson);	build2(lson);
	pushup2(i);
//	cout<<l<<'~'<<r<<' '<<t2[i][p].sum<<endl;
}
ll query1(int i,int l,int r,int x,int y){
	if(l>y||r<x)	return 0;
	if(x<=l&&r<=y)	return t1[i][p].sum;
	return max(query1(lson,x,y),query1(rson,x,y));
}
ll query2(int i,int l,int r,int x,int y){
	if(l>y||r<x)	return 0;
	if(x<=l&&r<=y)	return t2[i][p].sum;
	return max(query2(lson,x,y),query2(rson,x,y));
}
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		ll x,y;	x=read();	y=read();
		a[i][1]=a[i][7]=x;
		a[i][2]=a[i][6]=x+y;
		a[i][3]=a[i][5]=x+2*y;
		a[i][4]=x+3*y;
	}
	for(int i=1;i<=7;i++){	p=i;	build1(1,1,n);	}	//第1个位置为星期i 
	for(int i=1;i<=7;i++){	p=i;	build2(1,1,n);	}	//第n个位置为星期i
	m=read();
	while(m--){
		int x,y;	x=read();	y=read();
		if(x<y){
			//找到第1个位置为周几
			int k=x,j=1;
			while(k%7!=1){
				j--;	k--;
				if(!j)	j=7;
			}
			p=j;
			printf("%lld\n",query1(1,1,n,x,y));
		}
		else{
			//找到第n个位置为周几 
			int k=x,j=1;
			while((n-k)%7!=0){
				j--;	k++;
				if(!j)	j=7;
			}
			p=j;
			printf("%lld\n",query2(1,1,n,y,x));
		}
	}
}

其中这个查询操作有问题。

ll query1(int i,int l,int r,int x,int y){
	if(l>y||r<x)	return 0;
	if(x<=l&&r<=y)	return t1[i][p].sum;
	return max(query1(lson,x,y),query1(rson,x,y));
}

不应该return max(query1(lson,x,y),query1(rson,x,y)); 。少考虑了。左边查询区间如果有(相对的)最小值(该最小值比右区间最小值还小),用右边最大值减去左边最小值能比答案大。

查询 [3,9]
左区间 [3,7) , sum=4 , mx=5 ,mi=1
右区间 [7,9) , sum=6 , mx=10, mi=4

如果按照我原来错误的代码,出来的是 6
但正确是 10-1==9

卡这个细节卡了半天。
可以将查询函数改成传递结构体(题解就是这个做法),也可以设置全局变量,更简短。


修改后的ac代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int inf=0x3f3f3f3f;
ll read(){
    ll x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x*f;
}
#define ls		i<<1
#define rs		i<<1|1
#define mid     (l+r>>1)
#define lson	ls,l,mid
#define rson	rs,mid+1,r
const int maxn=1e5+7;
ll n,m,a[maxn][10],p;
struct node{	ll mx,mi,sum;	};
node t1[maxn<<2][10],t2[maxn<<2][10];
void pushup1(int i){
	t1[i][p].sum=t1[rs][p].mx-t1[ls][p].mi;
	t1[i][p].sum=max(t1[i][p].sum,max(t1[ls][p].sum,t1[rs][p].sum));
	t1[i][p].mi=min(t1[ls][p].mi,t1[rs][p].mi);
	t1[i][p].mx=max(t1[ls][p].mx,t1[rs][p].mx);
}
void pushup2(int i){
	t2[i][p].sum=t2[ls][p].mx-t2[rs][p].mi;
	t2[i][p].sum=max(t2[i][p].sum,max(t2[ls][p].sum,t2[rs][p].sum));
	t2[i][p].mi=min(t2[ls][p].mi,t2[rs][p].mi);
	t2[i][p].mx=max(t2[ls][p].mx,t2[rs][p].mx);
}
void build1(int i,int l,int r){
	if(l==r){
		int k=((l-1)%7+p)%7;
		if(!k)	k=7;
		t1[i][p].mx=t1[i][p].mi=a[l][k];
		return;
	}
	build1(lson);	build1(rson);
	pushup1(i);
}
void build2(int i,int l,int r){
	if(l==r){
		int k=((n-l)%7+p)%7;
		if(!k)	k=7;
		t2[i][p].mx=t2[i][p].mi=a[l][k];
		return;
	}
	build2(rson);	build2(lson);
	pushup2(i);
}
ll t;
ll query1(int i,int l,int r,int x,int y){
	if(l>y||r<x)	return 0;
	if(x<=l&&r<=y){
		ll temp=t1[i][p].mx-t;
		t=min(t,t1[i][p].mi);
		return max(t1[i][p].sum,temp);
	}
	return max(query1(lson,x,y),query1(rson,x,y));
}
ll query2(int i,int l,int r,int x,int y){
	if(l>y||r<x)	return 0;
	if(x<=l&&r<=y){
		ll temp=t2[i][p].mx-t;
		t=min(t,t2[i][p].mi);
		return max(t2[i][p].sum,temp);
	}
	return max(query2(rson,x,y),query2(lson,x,y));
}
int main(){
	n=read();
	for(int i=1;i<=n;i++){
		ll x,y;	x=read();	y=read();
		a[i][1]=a[i][7]=x;
		a[i][2]=a[i][6]=x+y;
		a[i][3]=a[i][5]=x+2*y;
		a[i][4]=x+3*y;
	}
	for(int i=1;i<=7;i++){	p=i;	build1(1,1,n);	}	//第1个位置为星期i 
	for(int i=1;i<=7;i++){	p=i;	build2(1,1,n);	}	//第n个位置为星期i
	m=read();
	while(m--){
		int x,y;	x=read();	y=read();	t=inf;
		if(x<y){
			//找到第1个位置为周几
			int k=x,j=1;
			while(k%7!=1){
				j--;	k--;
				if(!j)	j=7;
			}
			p=j;
			printf("%lld\n",query1(1,1,n,x,y));
		}
		else{
			//找到第n个位置为周几 
			int k=x,j=1;
			while((n-k)%7!=0){
				j--;	k++;
				if(!j)	j=7;
			}
			p=j;
			printf("%lld\n",query2(1,1,n,y,x));
		}
	}
}

题解给的代码

emmm,不是我的风格,不学。

#include <bits/stdc++.h>
#pragma GCC optiMinze(2)
using namespace std;
#define iny long long
const iny maxn = 1e5 + 10;
struct node
{
    int Max;
    int Min;
    int _Max;
    node(){
        Max=Min=_Max=0;
    }
};
node tree[15][maxn << 2];
int v[15][maxn], d[maxn], n, q, s;

inline void push_up(iny p)
{
    tree[s][p]._Max = max(tree[s][2 * p]._Max, max(tree[s][2 * p + 1]._Max, tree[s][2 * p + 1].Max - tree[s][2 * p].Min));
    tree[s][p].Max = max(tree[s][2 * p].Max, tree[s][2 * p + 1].Max);
    tree[s][p].Min = min(tree[s][2 * p].Min, tree[s][2 * p + 1].Min);
}

void build(iny p, iny l, iny r)
{
    if (l == r)
    {
        tree[s][p]._Max = 0;
        tree[s][p].Max = v[s][l];
        tree[s][p].Min = v[s][l];
        return;
    }
    iny Mind = l + r >> 1;
    build(p << 1, l, Mind);
    build((p << 1) + 1, Mind + 1, r);
    push_up(p);
}

node query(iny l, iny r, iny p, iny pl, iny pr)
{
    if (l <= pl && r >= pr)
    {
        return tree[s][p];
    }
    iny Mind = pl + pr >> 1;
    node ans1, ans2, ans3;
    if (l <= Mind)
    {
        ans2 = query(l, r, p << 1, pl, Mind);
        ans1 = ans2;
    }
    if (r > Mind)
    {
        ans3 = query(l, r, (p << 1) + 1, Mind + 1, pr);
        if (l <= Mind)
        {
            ans1._Max = max(ans3.Max - ans2.Min, max(ans3._Max, ans2._Max));
            ans1.Min = min(ans2.Min, ans3.Min);
            ans1.Max = max(ans2.Max, ans3.Max);
        }
        else
            ans1 = ans3;
    }
    return ans1;
}

int main()
{
    scanf("%d",&n);
    for (iny i = 1; i <= n; i++)
    {
        scanf("%d%d", &v[0][i], &d[i]);
    }
    for (iny i = 0; i <= 6; i++)
    {
        iny nw = i;
        for (iny j = 1; j <= n; j++)
        {
            if (nw == 0 || nw == 6)
            {
                v[i + 1][j] = v[0][j];
            }
            if (nw == 1 || nw == 5)
            {
                v[i + 1][j] = v[0][j] + d[j];
            }
            if (nw == 2 || nw == 4)
            {
                v[i + 1][j] = v[0][j] + 2 * d[j];
            }
            if (nw == 3)
            {
                v[i + 1][j] = v[0][j] + 3 * d[j];
            }
            nw++;
            nw %= 7;
        }
    }
    for (s = 1; s <= 7; s++)
        build(1, 1, n);
    for (iny i = 1; i <= n / 2; i++)
    {
        swap(v[0][i], v[0][n - i + 1]);
        swap(d[i], d[n - i + 1]);
    }
    for (iny i = 0; i <= 6; i++)
    {
        iny nw = i;
        for (iny j = 1; j <= n; j++)
        {
            if (nw == 0 || nw == 6)
            {
                v[i + 8][j] = v[0][j];
            }
            if (nw == 1 || nw == 5)
            {
                v[i + 8][j] = v[0][j] + d[j];
            }
            if (nw == 2 || nw == 4)
            {
                v[i + 8][j] = v[0][j] + 2 * d[j];
            }
            if (nw == 3)
            {
                v[i + 8][j] = v[0][j] + 3 * d[j];
            }
            nw++;
            nw %= 7;
        }
    }
    for (s = 8; s <= 14; s++)
        build(1, 1, n);
    scanf("%d", &q);
    while (q--)
    {
        iny l, r;
        cin >> l >> r;
        if (l <= r)
        {
            s = (700000 + 2 - l) % 7;
            if (s == 0)
                s = 7;
        }
        else
        {
            s = (700000 + l - n + 1) % 7;
            if (s == 0)
                s = 7;
            s += 7;
            l = n - l + 1;
            r = n - r + 1;
        }
        node ans = query(l, r, 1, 1, n);
        printf("%d\n", ans._Max);
    }
    return 0;
}

学长的倍增方法

没研究,先附一下。

#pragma GCC optimize(3)
#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
#define dl(x) printf("%lld\n",x);
#define di(x) printf("%d\n",x);
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const ll INF= 1e17+7;
const ll maxn = 1e5+700;
const int M = 1e6+8;
const ll mod= 1e9+7;
const double eps = 1e-9;
const double PI = acos(-1);
template<typename T>inline void read(T &a){char c=getchar();T x=0,f=1;while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+c-'0';c=getchar();}a=f*x;}
ll n,m,p;
int a[maxn],d[maxn];
int val[7][maxn];
int mx[7][maxn][18];
int mi[7][maxn][18];
int mxv[7][maxn][18],miv[7][maxn][18];
int lg[maxn];
void Restart(int id){
	for(int i=1;i<=n;i++){
		mx[id][i][0] = mi[id][i][0] = val[id][i];
		mxv[id][i][0] = miv[id][i][0] = 0;
	}
	///st[i][j]=max(st[i][j-1],st[i+(1<<(j-1))][j-1]);
	for(int k=1;k<20;k++){
		for(int i=1;i+(1<<k)-1<=n;i++){
			mx[id][i][k] = max(mx[id][i][k-1],mx[id][i+(1<<(k-1))][k-1]);
			mi[id][i][k] = min(mi[id][i][k-1],mi[id][i+(1<<(k-1))][k-1]);

			mxv[id][i][k] = max(mxv[id][i][k-1],mxv[id][i+(1<<(k-1))][k-1]);
			miv[id][i][k] = min(miv[id][i][k-1],miv[id][i+(1<<(k-1))][k-1]);

			mxv[id][i][k] = max(mxv[id][i][k],mx[id][i+(1<<(k-1))][k-1]-mi[id][i][k-1]);
			miv[id][i][k] = min(miv[id][i][k],mi[id][i+(1<<(k-1))][k-1]-mx[id][i][k-1]);
		}
	}
}
void __inint(int f){
	for(int i=1;i<=n;i++) lg[i]=lg[i/2]+1;
	for(int k=0;k<7;k++){
		int s = k;
		for(int i=1;i<=n;i++){
			if(s%7 == 0 || s%7 == 1) val[k][i] = a[i];
			if(s%7 == 2 || s%7 == 6) val[k][i] = a[i]+d[i];
			if(s%7 == 3 || s%7 == 5) val[k][i] = a[i]+2*d[i];
			if(s%7 == 4) val[k][i] = a[i]+3*d[i];
			if(!f) s = (s+1)%7;
			else s = (s+6)%7;
		}
	}
	/*for(int k=0;k<7;k++){
		printf("%d:\n",k);
		for(int i=1;i<=n;i++) 
			printf("%lld ",val[k][i]);
		printf("\n");
	}*/
	for(int k=0;k<7;k++) Restart(k);
}
ll getMax(int id,int x,int y){
  int len=lg[y-x+1]-1;
  return max(mx[id][x][len],mx[id][y-(1<<len)+1][len]);
}
ll getMin(int id,int x,int y){
  int len=lg[y-x+1]-1;
  return min(mi[id][x][len],mi[id][y-(1<<len)+1][len]);
}
ll getMaxVal(int id,int x,int y){
	if(x == y) return 0;
	ll res = 0;
	int len=lg[y-x+1]-1;
	res = max(mxv[id][x][len],mxv[id][y-(1<<len)+1][len]);
	int mid = (y+x)/2;
	res = max(res,getMax(id,mid+1,y)-getMin(id,x,mid));
	return res;
}
ll getMinVal(int id,int x,int y){
	if(x == y) return 0;
	ll res = INF;
	int len=lg[y-x+1]-1;
	res = min(miv[id][x][len],miv[id][y-(1<<len)+1][len]);
	int mid = (y+x)/2;
	res = min(res,getMin(id,mid+1,y)-getMax(id,x,mid));
	return res;
}
int res[maxn];
vector< pair< pair<int,int>,int> >v,g;
int main(){
	read(n);
  for(int i=1;i<=n;i++){
  	read(a[i]);
  	read(d[i]);
  }
  read(m);
  for(int i=1;i<=m;i++){
  	int x,y;read(x);read(y);
  	if(x == y) res[i] = 0;
  	else if(x>y) v.push_back({{x,y},i});
  	else g.push_back({{x,y},i});
  }
  __inint(0);
  for(auto t:g){
  	int idx = t.second,x = t.first.first,y = t.first.second;
  	int id = (1-(x-1)%7+7)%7;
  	res[idx] = max(getMaxVal(id,x,y),0ll);
  }
  __inint(1);
  for(auto t:v){
  	int idx = t.second,x = t.first.first,y = t.first.second;
  	int id = (1+(x-1)%7+7)%7;
  	res[idx] = max(-getMinVal(id,y,x),0ll);
  }
  for(int i=1;i<=m;i++)
  	printf("%d\n",res[i]);
  return 0;
}
/***

***/
上一篇:【练习-9】最大子序列和==>最大子矩阵和(#若是)


下一篇:P2455 [SDOI2006]线性方程组