FOJ Moving Points 题解

题目描述

在一个一维坐标轴上给定n(2<=n<=2×105)个点的初始坐标x1,x2,x3,x4……xn(1<=xi<=108),每个点有一个初速度v1,v2,v3,v4……vn(-108<=vi<=108)所有点同时开始向右移动,在运动过程中每两个点都会存在一个最短距离,问所有点两两之间的最短距离和是多少。

输入

第一行为n
第二行n个数字,表示n个点的位置xi
第三行n个数字,表示n个点的速度vi

输出

所有点两两之间出现过的最短距离和

样例输入

3
1 3 2
-100 2 3

样例输出

3

题解

可以发现关于任意两点i, j满足xi<xj,如果vi<=vj,则最短距离为(xj-xi);如果vi>vj,则最短距离为0。
于是我们可以将所有点按x排序,每次枚举点i前面的所有点,对于速度小于vi的点,将答案加上二者间的距离。
然而此法时间复杂度为O(N2),显然不满足要求。于是考虑权值线段树。
先对v离散化。对于线段树上的节点[L,R],记录v在[L,R]区间内的点的个数tot,以及这些点到原点的距离总和sum。那么对于点i,线段树上[1,vi-1]区间内的tot即为速度小于vi的点的个数,sum即为速度小于vi的点到原点的距离总和。
考虑三点i,j,k满足xi>xj,xi>xk,则xi-xj+xi-xk=xi * 2-(xj+xk)。于是推广得到i与速度小于vi的所有点的距离总和为xi * tot-sum。
总体时间复杂度为O(NlogN)。

代码

#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 2e5 + 1;
struct node {
	int x, v;
	bool operator <(node other) const {
		return x < other.x;
	}
} a[N];
struct Pair {
	int tot;
	long long sum;
	Pair operator +(Pair other) const {
		return (Pair){tot + other.tot, sum + other.sum};
	}
};
struct tree {
	int l, r;
	Pair dat;
} t[4 * N];
int n;
int mpv[N], m;
int ask(int x) {
	return lower_bound(mpv + 1, mpv + m + 1, x) - mpv;
}
void Build(int p, int l, int r) {
	t[p].l = l; t[p].r = r;
	if (l == r) {
		t[p].dat = (Pair){0, 0}; return ;
	}
	int mid = (l + r) >> 1;
	Build(p * 2, l, mid); Build(p * 2 + 1, mid + 1, r);
}
void change(int p, int x, int v1, int v2) {
	if (t[p].l == t[p].r) {
		t[p].dat = t[p].dat + (Pair){v1, v2}; return ;
	}
	int mid = (t[p].l + t[p].r) >> 1;
	if (x <= mid) change(p * 2, x, v1, v2);
	else change(p * 2 + 1, x, v1, v2);
	t[p].dat = t[p * 2].dat + t[p * 2 + 1].dat;
}
Pair query(int p, int l, int r) {
	if (l <= t[p].l && t[p].r <= r) return t[p].dat;
	int mid = (t[p].l + t[p].r) >> 1;
	Pair ans = (Pair){0, 0};
	if (l <= mid) ans = ans + query(p * 2, l, r);
	if (r > mid) ans = ans + query(p * 2 + 1, l, r);
	return ans;
}
int main() {
	long long ans = 0;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d", &a[i].x);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i].v); mpv[i] = a[i].v;
	}
	sort(a + 1, a + n + 1);
	sort(mpv + 1, mpv + n + 1);
	m = unique(mpv + 1, mpv + n + 1) - (mpv + 1);
	Build(1, 1, m);
	for (int i = 1; i <= n; i++) {
		Pair res = query(1, 1, ask(a[i].v));
		ans += (long long)a[i].x * res.tot - res.sum;
		change(1, ask(a[i].v), 1, a[i].x);
	}
	printf("%lld\n", ans);
	return 0;
}

上一篇:Mac下java编译乱码(适用于maven , ant)


下一篇:FOJ Moving Points 题解