// BZOJ 3224: Ordinary Balanced Tree
// Splay tree template, kept here for reference.
#include <iostream>
#include <cstdio>
using namespace std;
//define
// Base capacity; every per-node array below holds (maxn << 1) slots.
const int maxn = 1e5 + 5;

/*
 * Array-backed splay tree storing a multiset of ints.
 *
 * Nodes are indices into the parallel arrays; index 0 is the null node and
 * is never modified, so ch[0][*], par[0], cnt[0], size[0] and val[0] stay 0.
 *   ch[v][0] / ch[v][1]  left / right child of v (0 = none)
 *   par[v]               parent of v (0 for the root)
 *   val[v]               key stored in v
 *   cnt[v]               number of copies of val[v]
 *   size[v]              sum of cnt over the subtree rooted at v
 *   ntot                 number of nodes allocated so far
 *   trt                  index of the current root (0 = empty tree)
 *
 * pre(), src() and remove() only find a neighbour that is actually stored,
 * so callers are expected to insert one key below and one key above every
 * real key (sentinels) before using them.  No capacity check is performed:
 * at most (maxn << 1) - 1 distinct keys may be inserted.
 */
struct Splay{
    int ch[maxn<<1][2], par[maxn<<1], cnt[maxn<<1], size[maxn<<1], val[maxn<<1], ntot, trt;

    // Returns 1 if rt is the right child of its parent, 0 otherwise.
    int chk(int rt){ return ch[par[rt]][1] == rt; }

    // Recomputes size[rt] from its children and its own multiplicity.
    void pushup(int rt){ size[rt] = size[ch[rt][0]] + size[ch[rt][1]] + cnt[rt]; }

    // Rotates rt above its parent y, preserving the in-order sequence.
    // rt must have a non-null parent.
    void rotate(int rt){
        int y = par[rt];
        int z = par[y];
        int rnk = chk(rt);
        int w = ch[rt][rnk^1];      // inner subtree of rt, re-hung under y
        ch[y][rnk] = w;
        if(w) par[w] = y;           // never write through the null node
        if(z) ch[z][chk(y)] = rt;   // y is the root when z == 0
        par[rt] = z;
        ch[rt][rnk^1] = y;
        par[y] = rt;
        pushup(y);                  // y is now below rt, so refresh it first
        pushup(rt);
    }

    // Rotates rt upward until its parent is g (g == 0: rt becomes the root).
    // g must be 0 or an ancestor of rt.
    void splay(int rt, int g = 0){
        if(!rt) return;             // splaying the null node would clobber trt
        while(par[rt] != g){
            int y = par[rt], z = par[y];
            if(z != g){
                // Same side twice: rotate the parent first; otherwise rotate rt twice.
                if(chk(rt) == chk(y)) rotate(y);
                else rotate(rt);
            }
            rotate(rt);
        }
        if(!g) trt = rt;
    }

    // Splays to the root the node holding key rt, or, if the key is absent,
    // the last node on its search path.  Does nothing on an empty tree.
    void find(int rt){
        if(!trt) return;
        int cur = trt;
        while(ch[cur][rt > val[cur]] && rt != val[cur]){
            cur = ch[cur][rt > val[cur]];
        }
        splay(cur);
    }

    // Inserts one copy of key rt and splays its node to the root.
    void insert(int rt){
        int cur = trt, p = 0;
        while(cur && val[cur] != rt){
            p = cur;
            cur = ch[cur][rt > val[cur]];
        }
        if(cur){
            cnt[cur]++;
        }
        else{
            cur = ++ntot;
            if(p) ch[p][rt > val[p]] = cur;
            ch[cur][0] = ch[cur][1] = 0;
            val[cur] = rt;
            par[cur] = p;
            cnt[cur] = size[cur] = 1;
        }
        splay(cur);
        // splay() is a no-op when cur is already the root, which would leave
        // size[cur] stale after cnt[cur]++; refresh it explicitly.
        pushup(cur);
    }

    // Returns the node holding the k-th smallest stored element (1-based,
    // copies counted individually), or 0 if k exceeds the number of elements.
    // For k <= 0 the leftmost node is returned.
    int kth(int k){
        int cur = trt;
        while(cur){
            int left = ch[cur][0];
            if(left && k <= size[left]){
                cur = left;
            }
            else if(k > size[left] + cnt[cur]){
                k -= size[left] + cnt[cur];
                cur = ch[cur][1];
            }
            else{
                return cur;
            }
        }
        return 0;
    }

    // Returns the node with the largest key strictly below rt, or 0 if none
    // is stored.  Restructures the tree via find().
    int pre(int rt){
        find(rt);
        if(val[trt] < rt) return trt;
        int cur = ch[trt][0];
        while(ch[cur][1]) cur = ch[cur][1];
        return cur;
    }

    // Returns the node with the smallest key strictly above rt, or 0 if none
    // is stored.  Restructures the tree via find().
    int src(int rt){
        find(rt);
        if(val[trt] > rt) return trt;
        int cur = ch[trt][1];
        while(ch[cur][0]) cur = ch[cur][0];
        return cur;
    }

    // Removes one copy of key rt; does nothing if the key is absent.
    // Needs both a stored predecessor and a stored successor of rt (the
    // sentinels); if either is missing the call is a no-op.
    void remove(int rt){
        int last = pre(rt), next = src(rt);
        if(!last || !next) return;
        // With last at the root and next directly below it, every key
        // strictly between them sits in next's left subtree.
        splay(last); splay(next, last);
        int del = ch[next][0];
        if(!del) return;            // key not present
        if(cnt[del] > 1){
            cnt[del]--;
            splay(del);             // the rotations refresh all sizes on the path
        }
        else{
            ch[next][0] = 0;
            pushup(next);           // keep subtree sizes consistent after detaching
            pushup(last);
        }
    }

    // Returns the size of the root's left subtree.  Right after find(x) with
    // x stored, this is the number of stored elements smaller than x.
    int get_rt(){
        return size[ch[trt][0]];
    }

    // Returns the key of the (rt+1)-th smallest stored element, i.e. the
    // rt-th one when a single low sentinel is stored; yields val[0] (0) when
    // that position does not exist.
    int get_kth(int rt){
        return val[kth(rt+1)];
    }

    // Key of the predecessor node of rt (see pre()).
    int get_pre(int rt){
        return val[pre(rt)];
    }

    // Key of the successor node of rt (see src()).
    int get_src(int rt){
        return val[src(rt)];
    }
}tree;
//main
// Reads n followed by n operations "op x" from stdin and answers the queries
// on the global splay tree `tree`, one answer per line:
//   1 x  insert x                 2 x  remove one copy of x
//   3 x  rank of x                4 x  x-th smallest element
//   5 x  predecessor of x         6 x  successor of x
int main(){
    // Untie cin from cout and end lines with '\n' instead of endl so that
    // output is not flushed after every single answer.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Sentinel keys: same values as the original 0x3f3f3f3f / 0xcfcfcfcf
    // literals, with the unsigned-to-int conversion made explicit.
    const int kPosInf = 0x3f3f3f3f;
    const int kNegInf = static_cast<int>(0xcfcfcfcfu);

    int n, op, x;
    if(!(std::cin >> n)) return 0;
    tree.insert(kPosInf);
    tree.insert(kNegInf);
    while(n--){
        if(!(std::cin >> op >> x)) break;   // stop on truncated or malformed input
        switch(op){
            case 1:
                tree.insert(x);
                break;
            case 2:
                tree.remove(x);
                break;
            case 3: {
                tree.find(x);
                // The root's left subtree holds the low sentinel plus every
                // element smaller than the root, so its size is the rank when
                // the root's key is >= x.
                int rank = tree.get_rt();
                // If x is absent and the search stopped on a smaller key, the
                // copies stored in the root are smaller than x as well.
                if(tree.val[tree.trt] < x) rank += tree.cnt[tree.trt];
                std::cout << rank << '\n';
                break;
            }
            case 4:
                std::cout << tree.get_kth(x) << '\n';
                break;
            case 5:
                std::cout << tree.get_pre(x) << '\n';
                break;
            case 6:
                std::cout << tree.get_src(x) << '\n';
                break;
            default:
                break;                      // unknown operation codes are ignored
        }
    }
    return 0;
}