\(\text{Treap}\)是一种通过维护堆和二叉搜索树两重性质来限制树高的平衡树。
每一个节点都被赋予了两个值,称作\(val\)和\(key\),分别按照BST和堆的性质维护:
- 小根堆:每一个节点的\(key\)都小于等于他的两个子节点的\(key\)
- 二叉树:每一个节点的\(key\)都比他左子的\(val\)大,比右子的\(val\)小
每一个节点的\(key\)是随机指定的,这样可以保证期望树高为\(\log n\)
维护堆的性质其实很简单,只要在每次插入完节点后循环不停地判断当前节点的\(key\)是否小于父节点的\(key\),如果小于则直接把当前节点旋转上去
void adjust(node* p) { while (p->father && p->father->key > p->key) rotate(p); } void insert(int val, node*& p = root, node* from = 0) { if (!p) { p = ALLOCATE; p->father = from; p->val = val; p->key = rand(); p->cnt = p->size = 1; adjust(p); } else { p->size++; if (val < p->val) insert(val, p->lson, p); else if (val == p->val) p->cnt++; else /* val > p->val */ insert(val, p->rson, p); } }
\(\text{Treap}\)删除节点时先要把这个节点向下旋转到只有一个叶子节点/没有叶子节点时再直接把他的父节点和子节点链接,进而切掉这个节点
旋转时要注意选择\(key\)较小的那一个子节点进行旋转,以维持小根堆的性质
void destroy(node* p) { while (p->lson&&p->rson) if (p->lson->val < p->rson->val) rotate(p->lson); else rotate(p->rson); if (p == root) root = (p->lson ? p->lson : p->rson); connect(p->father, (p->lson ? p->lson : p->rson), tell(p)); }
合并两个\(\text{Treap}\)其实没有特殊的方法,只能启发式合并:暴力提取每一个较小的\(\text{Treap}\)中的节点,暴力插到大的那个中
其他地方和裸的\(\text{BST}\)就没有本质区别了(看看人家\(\text{Treap}\)多好写,再看看你\(\text{splay}\)
/* DOCUMENT NAME "20181021-luogu3369.cpp" CREATION DATE 2018-10-21 SIGNATURE CODE_20181021_LUOGU3369 COMMENT P3369 【模板】普通平衡树 / Treap */ #include <cstdlib> #include <iostream> #include <algorithm> #include <ctime> #include <cstdio> using namespace std; const int MaxN = 100000 + 10; template<typename IntType> void read(IntType& val) { val = 0; int c; bool inv = false; while (!isdigit(c = getchar())) if (c == '-') inv = true; do { val = (val << 1) + (val << 3) + c - '0'; } while (isdigit(c = getchar())); if (inv) val = -val; } struct node { int val, key; int cnt, size; node* lson, *rson; node* father; }; node* root; node mem[MaxN], *memtop = mem; #define ALLOCATE (++memtop) typedef int sontype; const sontype lson = 0, rson = 1; sontype tell(node* son) { return (!son->father || son->father->lson == son) ? lson : rson; } node*& get(node* father, sontype type) { return type == lson ? father->lson : father->rson; } void connect(node* father, node* son, sontype type) { if (father) get(father, type) = son; if (son) son->father = father; } int size(node* p) { return p ? p->size : 0; } void update(node* p) { p->size = size(p->lson) + size(p->rson) + p->cnt; } void rotate(node* p) { if (!p->father) return; sontype t = tell(p); node* f = p->father, *b = get(p, 1 - t); connect(f->father, p, tell(f)); connect(p, f, 1 - t); connect(f, b, t); update(f); update(p); if (!p->father) root = p; } void adjust(node* p) { while (p->father&&p->father->key > p->key) rotate(p); } void insert(int val, node*& p = root, node* from = 0) { if (!p) { p = ALLOCATE; p->father = from; p->val = val; p->key = rand(); p->cnt = p->size = 1; if (from&&p->key < from->key) adjust(p); } else { p->size++; if (val < p->val) insert(val, p->lson, p); else if (val == p->val) p->cnt++; else /* val > p->val */ insert(val, p->rson, p); } } void destroy(node* p) { while (p->lson&&p->rson) if (p->lson->val < p->rson->val) rotate(p->lson); else rotate(p->rson); if (p == root) root = (p->lson ? p->lson : p->rson); connect(p->father, (p->lson ? p->lson : p->rson), tell(p)); } void erase(int val, node* p = root) { p->size--; if (val < p->val) erase(val, p->lson); else if (val == p->val) { p->cnt--; if (!p->cnt) destroy(p); } else /* val > p->val */ erase(val, p->rson); } int getrank(int val, node* p = root) { if (!p) return 1; else if (val < p->val) return getrank(val, p->lson); else if (val == p->val) return size(p->lson) + 1; else /* val > p->val */ return getrank(val, p->rson) + size(p->lson) + p->cnt; } int getkth(int k, node* p = root) { if (k <= size(p->lson)) return getkth(k, p->lson); else if (k <= size(p->lson) + p->cnt) return p->val; else return getkth(k - size(p->lson) - p->cnt, p->rson); } int getprev(int val, node* p = root) { int ans = -1e8; while (p) { if (val <= p->val) p = p->lson; else { /* val > p->val */ ans = max(ans, p->val); p = p->rson; } } return ans; } int getnext(int val, node* p = root) { int ans = 1e8; while (p) { if (val < p->val) { ans = min(ans, p->val); p = p->lson; } else /* val >= p->val */ p = p->rson; } return ans; } int n; int opt, x; int main(int argc, char* argv[]) { srand(time(0)); read(n); for (int i = 1; i <= n; i++) { read(opt); read(x); switch (opt) { case 1: insert(x); break; case 2: erase(x); break; case 3: printf("%d\n", getrank(x)); break; case 4: printf("%d\n", getkth(x)); break; case 5: printf("%d\n", getprev(x)); break; case 6: printf("%d\n", getnext(x)); break; } } return 0; }