P5057 [CQOI2006]简单题
link
由于题中涉及“区间”操作等字样,容易想到线段树解法。
考虑如何 pushup 和 pushdown。
一种正常的方法
按照题意模拟,设懒标记 tag 表示数组反转时给这个节点带来的影响。这里仅提供关键函数的代码。
void pushdown(int p) {
if (tr[p].tag) {
tr[p << 1].tag ^= 1; // 按照题意取反
tr[p << 1 | 1].tag ^= 1;
tr[p].tag = 0;
}
}
void update(int p, int l, int r) {
if (tr[p].l >= l && tr[p].r <= r) return tr[p].tag ^= 1, void();
pushdown(p);
int mid = tr[p].l + tr[p].r >> 1;
if (l <= mid) update(p << 1, l, r);
if (r > mid) update(p << 1 | 1, l, r);
}
int ask(int p, int x) {
if (tr[p].l == tr[p].r) return tr[p].tag ? tr[p].val ^= 1, tr[p].tag = 0, tr[p].val : tr[p].val;
pushdown(p);
int mid = tr[p].l + tr[p].r >> 1;
if (x <= mid) return ask(p << 1, x); else return ask(p << 1 | 1, x);
}
另外一种方法
让我们来模拟一组数据
4 3
1 1 4
1 1 3
1 2 4
操作 | 模拟 |
---|---|
1 1 4 | 将数列中 1~4 的数字全部取反 此时数列变为:1 1 1 1 |
1 1 3 | 将数列中 1~3 的数字全部取反 此时数列变为:0 0 0 1 |
1 2 4 | 将数列中 2~4 的数字全部取反 此时数列变为:0 1 1 0 |
此时如果将“取反”操作改为“+1”,则有:
操作 | 模拟 |
---|---|
1 1 4 | 将数列中 1~4 的数字全部 |
1 1 3 | 将数列中 1~3 的数字全部 |
1 2 4 | 将数列中 2~4 的数字全部 |
若我们将每一次“+1”操作后的数列对 2 取模,观察结果
操作 | 取反 | “+1” | “+1”取模 |
---|---|---|---|
1 1 4 | 1 1 1 1 | 1 1 1 1 | 1 1 1 1 |
1 1 3 | 0 0 0 1 | 2 2 2 1 | 0 0 0 1 |
1 2 4 | 0 1 1 0 | 2 2 2 1 | 0 1 1 0 |
我们可以惊奇地发现,取模后的结果与对原数列取反的结果相同。
证明:
若有一个数 ,其取反的结果为 0,即
设有数列
若对数列中 取反:
则变成
正确性显然
#include <bits/stdc++.h>
using namespace std;
#define debug(args...) fprintf(stderr, ##args)
using ll = long long;
const int N = 1e5 + 7;
struct segtree {
struct node {
int l, r;
int sum, add;
} tr[N << 2];
void pushup(int p) { tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum; }
void pushdown(int p) {
tr[p << 1].sum += tr[p].add * (tr[p << 1].r - tr[p << 1].l + 1);
tr[p << 1 | 1].sum += tr[p].add * (tr[p << 1 | 1].r - tr[p << 1 | 1].l + 1);
tr[p << 1].add += tr[p].add;
tr[p << 1 | 1].add += tr[p].add;
tr[p].add = 0;
}
void build(int p, int l, int r) {
tr[p].l = l, tr[p].r = r;
if (l == r) return tr[p].sum = 0, void();
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}
void update(int p, int l, int r) {
if (tr[p].l >= l && tr[p].r <= r) return tr[p].add += 1, tr[p].sum += tr[p].r - tr[p].l + 1, void();
pushdown(p);
int mid = tr[p].l + tr[p].r >> 1;
if (l <= mid) update(p << 1, l, r);
if (r > mid) update(p << 1 | 1, l, r);
pushup(p);
}
int query(int p, int l, int r) {
if (tr[p].l >= l && tr[p].r <= r) return tr[p].sum % 2;
pushdown(p);
int mid = tr[p].l + tr[p].r >> 1, ans = 0;
if (l <= mid) return ans += query(p << 1, l, r);
if (r > mid) return ans += query(p << 1 | 1, l, r);
return ans;
}
} st;
int n, m;
int main() {
scanf("%d %d", &n, &m);
st.build(1, 1, n);
for (int i = 1, x, y, z; i <= m; ++i) {
scanf("%d", &x);
if (x == 1) scanf("%d %d", &y, &z); else scanf("%d", &y);
if (x == 1) st.update(1, y, z);
if (x == 2) printf("%d\n", st.query(1, y, y));
}
return 0;
}
听说这题在当年是道紫题
3x经验:P2574 | P3870 | SP7259
听说三道题交一样的代码能过
code
#include <bits/stdc++.h>
using namespace std;
// 0. Enough array size? Enough array size? Enough array size? Integer overflow?
// 1. Think TWICE, Code ONCE!
// Are there any counterexamples to your algo?
// 2. Be careful about the BOUNDARIES!
// N=1? P=1? Something about 0?
// 3. Do not make STUPID MISTAKES!
// Time complexity? Memory usage? Precision error?
#define debug(args...) fprintf(stderr, ##args)
#define gc getchar
using ll = long long;
const int inf = 1 << 31 - 1;
const int N = 3e5 + 7;
int n, m;
struct segtree {
struct node {
int l, r;
int sum, add;
} tr[N << 2];
void pushup(int p) { tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum; }
void pushdown(int p, int x) {
if (tr[p].add) {
tr[p << 1].add ^= 1;
tr[p << 1 | 1].add ^= 1;
int mid = x >> 1;
tr[p << 1].sum = (x - mid) - tr[p << 1].sum;
tr[p << 1 | 1].sum = mid - tr[p << 1 | 1].sum;
tr[p].add = 0;
}
}
void build(int p, int l, int r) {
tr[p].l = l, tr[p].r = r;
if (l == r) return tr[p].sum = 0, void();
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}
void update(int p, int l, int r) {
pushdown(p, tr[p].r - tr[p].l + 1);
if (tr[p].l >= l && tr[p].r <= r) return tr[p].add ^= 1, tr[p].sum = tr[p].r - tr[p].l + 1 - tr[p].sum, void();
int mid = tr[p].l + tr[p].r >> 1;
if (l <= mid) update(p << 1, l, r);
if (r > mid) update(p << 1 | 1, l, r);
pushup(p);
}
int ask(int p, int l, int r) {
if (tr[p].l >= l && tr[p].r <= r) return tr[p].sum;
pushdown(p, tr[p].r - tr[p].l + 1);
int mid = tr[p].l + tr[p].r >> 1, ans = 0;
if (l <= mid) ans += ask(p << 1, l, r);
if (r > mid) ans += ask(p << 1 | 1, l, r);
return ans;
}
} st;
int main() {
scanf("%d %d", &n, &m);
st.build(1, 1, n);
for (int i = 1, op, l, r; i <= m; ++i) {
scanf("%d %d %d", &op, &l, &r);
if (op == 0) st.update(1, l, r); else printf("%d\n", st.ask(1, l, r));
}
return 0;
}