Treap

[mathjax]

\(\text{Treap}\)是一种通过维护堆和二叉搜索树两重性质来限制树高的平衡树。

每一个节点都被赋予了两个值,称作\(val\)和\(key\),分别按照BST和堆的性质维护:

  • 小根堆:每一个节点的\(key\)都小于等于他的两个子节点的\(key\)
  • 二叉树:每一个节点的\(key\)都比他左子的\(val\)大,比右子的\(val\)小

每一个节点的\(key\)是随机指定的,这样可以保证期望树高为\(\log n\)


维护堆的性质其实很简单,只要在每次插入完节点后循环不停地判断当前节点的\(key\)是否小于父节点的\(key\),如果小于则直接把当前节点旋转上去

void adjust(node* p) {
    while (p->father && p->father->key > p->key)
        rotate(p);
}

void insert(int val, node*& p = root, node* from = 0) {
    if (!p) {
        p = ALLOCATE;
        p->father = from;
        p->val = val;
        p->key = rand();
        p->cnt = p->size = 1;
        adjust(p);
    } else {
        p->size++;
        if (val < p->val)
            insert(val, p->lson, p);
        else if (val == p->val)
            p->cnt++;
        else /* val > p->val */
            insert(val, p->rson, p);
    }
}

\(\text{Treap}\)删除节点时先要把这个节点向下旋转到只有一个叶子节点/没有叶子节点时再直接把他的父节点和子节点链接,进而切掉这个节点

旋转时要注意选择\(key\)较小的那一个子节点进行旋转,以维持小根堆的性质

void destroy(node* p) {
    while (p->lson&&p->rson)
        if (p->lson->val < p->rson->val)
            rotate(p->lson);
        else
            rotate(p->rson);
    if (p == root)
        root = (p->lson ? p->lson : p->rson);
    connect(p->father, (p->lson ? p->lson : p->rson), tell(p));
}

合并两个\(\text{Treap}\)其实没有特殊的方法,只能启发式合并:暴力提取每一个较小的\(\text{Treap}\)中的节点,暴力插到大的那个中

其他地方和裸的\(\text{BST}\)就没有本质区别了(看看人家\(\text{Treap}\)多好写,再看看你\(\text{splay}\)

/*
 DOCUMENT NAME "20181021-luogu3369.cpp"
 CREATION DATE 2018-10-21
 SIGNATURE CODE_20181021_LUOGU3369
 COMMENT P3369 【模板】普通平衡树 / Treap
*/

#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <ctime>
#include <cstdio>
using namespace std;

const int MaxN = 100000 + 10;

template<typename IntType>
void read(IntType& val) {
    val = 0;
    int c;
    bool inv = false;
    while (!isdigit(c = getchar()))
        if (c == '-')
            inv = true;
    do {
        val = (val << 1) + (val << 3) + c - '0';
    } while (isdigit(c = getchar()));
    if (inv)
        val = -val;
}

struct node {
    int val, key;
    int cnt, size;
    node* lson, *rson;
    node* father;
};

node* root;
node mem[MaxN], *memtop = mem;
#define ALLOCATE (++memtop)

typedef int sontype;
const sontype lson = 0, rson = 1;
sontype tell(node* son) { return (!son->father || son->father->lson == son) ? lson : rson; }
node*& get(node* father, sontype type) { return type == lson ? father->lson : father->rson; }
void connect(node* father, node* son, sontype type) {
    if (father)
        get(father, type) = son;
    if (son)
        son->father = father;
}

int size(node* p) { return p ? p->size : 0; }
void update(node* p) {
    p->size = size(p->lson) + size(p->rson) + p->cnt;
}

void rotate(node* p) {
    if (!p->father)
        return;
    sontype t = tell(p);
    node* f = p->father, *b = get(p, 1 - t);
    connect(f->father, p, tell(f));
    connect(p, f, 1 - t);
    connect(f, b, t);
    update(f);
    update(p);
    if (!p->father)
        root = p;
}

void adjust(node* p) {
    while (p->father&&p->father->key > p->key)
        rotate(p);
}

void insert(int val, node*& p = root, node* from = 0) {
    if (!p) {
        p = ALLOCATE;
        p->father = from;
        p->val = val;
        p->key = rand();
        p->cnt = p->size = 1;
        if (from&&p->key < from->key)
            adjust(p);
    } else {
        p->size++;
        if (val < p->val)
            insert(val, p->lson, p);
        else if (val == p->val)
            p->cnt++;
        else /* val > p->val */
            insert(val, p->rson, p);
    }
}

void destroy(node* p) {
    while (p->lson&&p->rson)
        if (p->lson->val < p->rson->val)
            rotate(p->lson);
        else
            rotate(p->rson);
    if (p == root)
        root = (p->lson ? p->lson : p->rson);
    connect(p->father, (p->lson ? p->lson : p->rson), tell(p));
}


void erase(int val, node* p = root) {
    p->size--;
    if (val < p->val)
        erase(val, p->lson);
    else if (val == p->val) {
        p->cnt--;
        if (!p->cnt)
            destroy(p);
    } else /* val > p->val */
        erase(val, p->rson);
}

int getrank(int val, node* p = root) {
    if (!p)
        return 1;
    else if (val < p->val)
        return getrank(val, p->lson);
    else if (val == p->val)
        return size(p->lson) + 1;
    else /* val > p->val */
        return getrank(val, p->rson) + size(p->lson) + p->cnt;
}

int getkth(int k, node* p = root) {
    if (k <= size(p->lson))
        return getkth(k, p->lson);
    else if (k <= size(p->lson) + p->cnt)
        return p->val;
    else
        return getkth(k - size(p->lson) - p->cnt, p->rson);
}

int getprev(int val, node* p = root) {
    int ans = -1e8;
    while (p) {
        if (val <= p->val)
            p = p->lson;
        else { /* val > p->val */
            ans = max(ans, p->val);
            p = p->rson;
        }
    }
    return ans;
}

int getnext(int val, node* p = root) {
    int ans = 1e8;
    while (p) {
        if (val < p->val) {
            ans = min(ans, p->val);
            p = p->lson;
        } else /* val >= p->val */
            p = p->rson;
    }
    return ans;
}

int n;
int opt, x;


int main(int argc, char* argv[]) {

    srand(time(0));

    read(n);
    for (int i = 1; i <= n; i++) {
        read(opt); read(x);
        switch (opt) {
            case 1:
                insert(x);
                break;
            case 2:
                erase(x);
                break;
            case 3:
                printf("%d\n", getrank(x));
                break;
            case 4:
                printf("%d\n", getkth(x));
                break;
            case 5:
                printf("%d\n", getprev(x));
                break;
            case 6:
                printf("%d\n", getnext(x));
                break;
        }
    }

    return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注