线段树掌握的很差,打算从头从最简单的开始刷一波, 嗯。。就从这个题开始吧!
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <cstdlib>
#include <algorithm>
const int maxn = +;
using namespace std;
int a[maxn], n;
struct line
{
int l, r, val;
}tr[maxn<<]; void build(int o, int l, int r)
{
tr[o].l = l; tr[o].r = r;
if(l == r)
{
tr[o].val = a[l];
return;
}
int mid = (l+r)>>;
build(*o, l, mid);
build(*o+, mid+, r);
tr[o].val = tr[*o].val+tr[*o+].val;
}
int query(int o, int l, int r)
{
if(tr[o].l==l && tr[o].r==r)
return tr[o].val;
int mid = (tr[o].l+tr[o].r)>>;
if(r <= mid) query(*o, l, r);
else if(l > mid) query(*o+, l, r);
else
return (query(*o, l, mid)+query(*o+, mid+, r));
}
void update(int o, int p, int add)
{
if(tr[o].l==tr[o].r&&tr[o].l==p)
{
tr[o].val += add;
return;
}
int mid = (tr[o].l+tr[o].r)>>;
if(p<=mid) update(*o, p, add);
else update(*o+, p, add);
tr[o].val = tr[*o].val+tr[*o+].val;
}
int main()
{
int t, i, ca = ;
int p, add, l, r;
char s[];
scanf("%d", &t);
while(t--)
{
scanf("%d", &n);
for(i = ; i <= n; i++)
scanf("%d", &a[i]);
printf("Case %d:\n", ca++); build(, , n);
while(~scanf("%s", s))
{
if(strcmp(s, "End")==) break;
if(s[]=='Q')
{
scanf("%d%d", &l, &r);
printf("%d\n", query(, l, r));
}
if(s[]=='A')
{
scanf("%d%d", &p, &add);
update(, p, add);
}
if(s[]=='S')
{
scanf("%d%d", &p, &add);
update(, p, -add);
}
}
}
return ;
}
注释的代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <cstdlib>
#include <algorithm>
const int maxn = +;
using namespace std;
int a[maxn], n;
struct line
{
int l, r, val; //val表示该区间的和
}tr[maxn<<]; void build(int o, int l, int r) //o代表当前节点编号
{
tr[o].l = l; tr[o].r = r;
if(l == r)
{
tr[o].val = a[l];
return;
}
int mid = (l+r)>>;
build(*o, l, mid);
build(*o+, mid+, r);
tr[o].val = tr[*o].val+tr[*o+].val; //建树把值从下往上加起来
}
int query(int o, int l, int r) //求l到r的和
{
if(tr[o].l==l && tr[o].r==r) //节点的区间吻合返回
return tr[o].val;
int mid = (tr[o].l+tr[o].r)>>;
if(r <= mid) query(*o, l, r);
else if(l > mid) query(*o+, l, r);
else
return (query(*o, l, mid)+query(*o+, mid+, r)); //横跨了区间
}
void update(int o, int p, int add) //对p节点增加add
{
if(tr[o].l==tr[o].r&&tr[o].l==p)
{
tr[o].val += add;
return;
}
int mid = (tr[o].l+tr[o].r)>>;
if(p<=mid) update(*o, p, add);
else update(*o+, p, add);
tr[o].val = tr[*o].val+tr[*o+].val; //找到值以后的更新
}
int main()
{
int t, i, ca = ;
int p, add, l, r;
char s[];
scanf("%d", &t);
while(t--)
{
scanf("%d", &n);
for(i = ; i <= n; i++)
scanf("%d", &a[i]);
printf("Case %d:\n", ca++); build(, , n);
while(~scanf("%s", s))
{
if(strcmp(s, "End")==) break;
if(s[]=='Q')
{
scanf("%d%d", &l, &r);
printf("%d\n", query(, l, r));
}
if(s[]=='A')
{
scanf("%d%d", &p, &add);
update(, p, add);
}
if(s[]=='S')
{
scanf("%d%d", &p, &add);
update(, p, -add);
}
}
}
return ;
}