树状数组求全局第 k 大 中求解了区间第 k 大的特殊情况——全局第 k 大。现利用线段树将其拓展到区间第 k 大。

下面中涉及到的字母量及其含义:

字母 含义
nn 数组长度
mm 询问个数
TT 离散化后映射到的值域(1Tn1 \le T \le n
aa 对原数据离散化后所对应的离散值数组

相同的第一步

树状数组求全局第 k 大 里,我们将数组中的数通过离散化映射到 1 ~ n 上,方便对值域建立线段树,区间第 k 大也一样。

std::vector<ll> a(n + 1), b;
a[0] = 0;
for (int i = 1; i <= n; ++i) {
    scanf("%lld", &a[i]);
    b.push_back(a[i]);
}

std::sort(b.begin(), b.end());
b.erase(std::unique(b.begin(), b.end()), b.end());
for (int i = 1; i <= n; ++i) {
    a[i] = std::lower_bound(b.begin(), b.end(), a[i]) - b.begin() + 1;
}

主要思想

在可持久化权值线段树中,我们抛弃了传统线段树节点的左右儿子各代表一段更小的区间的写法,转而记录每一个新增节点的编号值,使它指向左右儿子。

回到区间第 k 大的问题上来。具体来说,先对值域建立一颗可持久化权值线段树,在读入每一个 aia_i 的值时,与全局的做法类似,将第 aia_i 个位置增加 1。

与全局不同的是,我们需要先将之前的版本复制一遍,然后再在新的版本上单点增加。这样操作之后对于每一颗以 rooti(i[1,n])root_i (i \in [1, n]) 为根的线段树,区间 [L,R][L, R] 恰好保存了前 i 个数中有多少个数在 [L,R][L, R] 区间内。

考虑一组询问 li,ri,kil_i, r_i, k_i,显然 rootli\Large root_{l_i}rootri\Large root_{r_i} 对于值域的划分是相同的。可以这样理解,rootiroot_i 即表示对前 i 个数开桶计数的结果,因此,对于值域 [li,ri][l_i, r_i],线段树节点上 rootrirootli\Large root_{r_i} - \Large root_{l_i} 的结果为 lil_irir_i 在坐标轴上点分布的个数,即其满足可减性。综上,对于一组询问 li,ri,kil_i, r_i, k_i,我们直接在权值线段树上二分,计算 rootli\Large root_{l_i}rootri\Large root_{r_i} 的左子结点的贡献之差 c,如果 c 不大于 k,那么第 k 大数一定在左子树内,进入左子结点;否则则进入右子节点。

int kth(int p, int q, int l, int r, int k) {
    if (l == r) {
        return l;
    }
    int C = t[t[q].l].cnt - t[t[p].l].cnt;
    int mid = l + r >> 1;
    if (k <= C) {
        return kth(t[p].l, t[q].l, l, mid, k);
    } else {
        return kth(t[p].r, t[q].r, mid + 1, r, k - C);
    }
}

// 调用
seg s(n);
S.kth(root[l - 1], root[r], 1, t, k);

这样的算法(数据结构),我们称之为主席树。

一些细节

  • 我们不可能每次都把上一颗线段树复制一遍再操作,一个好的处理方法是只增加每次修改所影响的链,这样每次单点增操作所增加的空间复杂度为 O(logT)\mathrm O (\log T), 总体的空间复杂度为 O(n+T+(n+m)logT)=O((n+m)logT)\mathrm O (n + T + (n + m) \log T) = \mathrm O((n + m) \log T)。实际上开空间常用 O(ωn)\mathrm O(\omega \cdot n),其中 ω=32\omega = 32

  • 同样地,我们也不可能不可能再复制上一次的链,所以直接利用权值线段树的性质,在第 i - 1 个版本上修改即可。注意为了避免 RE,将第 0 个版本设置为全 0,表示前 0 个数在坐标轴上无点分布。

代码

#include <bits/stdc++.h>

using ll = long long;

struct seg {
    struct node {
        int cnt;
        int l, r;
    };

    int cnt = 0;
    std::vector<node> t;

    seg(int n) {
        t.assign((n << 5) + 7, {0, 0, 0});
    }

    int build(int l, int r) {
        int p = ++cnt;
        if (l == r) {
            return p;
        }
        int mid = l + r >> 1;
        t[p].l = build(l, mid);
        t[p].r = build(mid + 1, r);
        return p;
    }

    int insert(int now, int l, int r, int x) {
        int p = ++cnt;
        t[p] = t[now];
        ++t[p].cnt;
        if (l == r) {
            return p;
        }
        int mid = l + r >> 1;
        if (x <= mid) {
            t[p].l = insert(t[now].l, l, mid, x);
        } else {
            t[p].r = insert(t[now].r, mid + 1, r, x);
        }
        return p;
    }

    int kth(int p, int q, int l, int r, int k) {
        if (l == r) {
            return l;
        }
        int C = t[t[q].l].cnt - t[t[p].l].cnt;
        int mid = l + r >> 1;
        if (k <= C) {
            return kth(t[p].l, t[q].l, l, mid, k);
        } else {
            return kth(t[p].r, t[q].r, mid + 1, r, k - C);
        }
    }
};

auto main()->int {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    
    // freopen("log.txt", "w", stderr);

    int n, m;
    scanf("%d %d", &n, &m);

    std::vector<ll> a(n + 1), b;
    a[0] = 0;
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", &a[i]);
        b.push_back(a[i]);
    }

    std::sort(b.begin(), b.end());
    b.erase(std::unique(b.begin(), b.end()), b.end());
    for (int i = 1; i <= n; ++i) {
        a[i] = std::lower_bound(b.begin(), b.end(), a[i]) - b.begin() + 1;
    }

    int t = b.size();
    seg S(n); // 注意是 n 不是 t, 考虑所有数都相同的情况
    std::vector<int> root(n + 1, 0);
    root[0] = S.build(1, n);
    for (int i = 1; i <= n; ++i) {
        root[i] = S.insert(root[i - 1], 1, t, a[i]);
    }

    for (int i = 1, l, r, k; i <= m; ++i) {
        scanf("%d %d %d", &l, &r, &k);
        int res = S.kth(root[l - 1], root[r], 1, t, k);
        res = res > 0 ? res - 1 : res;
        printf("%lld\n", b[res]);
    }
    
    return 0;
}

附:gen.cpp

#include <bits/stdc++.h>

using ll = long long;

std::random_device rd;
std::mt19937 rng(rd());

constexpr int p = 1e9;
constexpr int q = 2e5;

int main() {
    int n = rng() % q + 1, m = rng() % q + 1;
    printf("%d %d\n", n, m);
    for (int i = 1; i <= n; ++i) {
        printf("%d ", rng() % (p << 1) - p);
    }
    printf("\n");

    for (int i = 1; i <= m; ++i) {
        int l = rng() % n + 1, r = rng() % n + 1;
        if (l > r) {
            std::swap(l, r);
        }
        printf("%d %d %d\n", l, r, rng() % (r - l + 1) + 1);
    }

    return 0;
}