P3380

【模板】二逼平衡树(树套树)

我一度以为我的 Splay 因为人傻常数大,直到我使用了 int。。。

真的,long long 太慢了。。。

区间线段树的每个子节点用一个 Splay 来存储信息,动态分配一下内存。

这个题的做法就比较显然,查排名的话直接把线段树区间上的排名累加即可。

查询排名为 $k$ 的数可以先二分答案再用第一个操作。

修改的话把包含这个点的区间修改就好了,和单点修改差不多。

前驱就是所有区间的前驱的最大值。

后继就是所有区间的后继的最小值。

然后注意 Splay 内部的内存需要有一个垃圾回收的方式,这里我开了一个栈。

long long 常数被卡飞,换了个 int 直接单车变摩托了。

时间复杂度 $O(n\log^2 n)$,瓶颈主要在第二个操作。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#include<iostream>
#include<cstdio>
#define ll int
using namespace std;
namespace Ehnaev{
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<48||ch>57) {if(ch==45) f=-f;ch=getchar();}
while(ch>=48&&ch<=57) {ret=(ret<<3)+(ret<<1)+ch-48;ch=getchar();}
return ret*f;
}
inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {do{buf[++len]=x%10+48;x/=10;}while(x);}
else {putchar(45);do{buf[++len]=-(x%10)+48;x/=10;}while(x);}
while(len>=0) putchar(buf[len--]);
}
}using Ehnaev::read;using Ehnaev::write;
inline void writeln(ll x) {write(x);putchar(10);}

const ll N=5e5,M=2e6,inf=(1ll<<31)-1;

ll n,m;
ll a[N+5];
ll buffa[M+5],bufch[2][M+5],bufcnt[M+5],bufval[M+5],bufsiz[M+5],bufst[M+5];
ll *nowfa=buffa,*nowch[2]={bufch[0],bufch[1]},*nowcnt=bufcnt
,*nowval=bufval,*nowsiz=bufsiz,*nowst=bufst;

struct Splay{
ll *fa,*ch[2],*cnt,*val,*siz,*st;ll rt,sz,top;
inline void Init(ll x) {
x+=5;fa=nowfa;ch[0]=nowch[0],ch[1]=nowch[1];cnt=nowcnt;
val=nowval;siz=nowsiz;st=nowst;
nowfa+=x;nowch[0]+=x;nowch[1]+=x;nowcnt+=x;nowval+=x;
nowsiz+=x;nowst+=x;
for(ll i=1;i<=x-2;i++) st[++top]=i;
}
inline ll Get(ll x) {return x==ch[1][fa[x]];}
inline void Pushup(ll x) {siz[x]=siz[ch[0][x]]+siz[ch[1][x]]+cnt[x];}
inline void Clear(ll x) {
fa[x]=ch[0][x]=ch[1][x]=cnt[x]=val[x]=siz[x]=0;st[++top]=x;
}
inline void Rotate(ll x) {
ll y=fa[x],z=fa[y],c=Get(x);ch[c][y]=ch[c^1][x];
if(ch[c^1][x]) fa[ch[c^1][x]]=y;ch[c^1][x]=y;fa[y]=x;fa[x]=z;
if(z) ch[y==ch[1][z]][z]=x;Pushup(y);Pushup(x);
}
inline void splay(ll x,ll g) {
if(!x) return;
for(ll f=fa[x];f=fa[x],f!=g;Rotate(x)) {
if(fa[f]!=g) Rotate(Get(f)==Get(x)?f:x);
}
if(!g) rt=x;
}
inline ll Find(ll k) {
// printf("Find break 1\n");
if(!rt) return 0;ll cur=rt;
// printf("Find break 2\n");
while(ch[k>val[cur]][cur]&&k!=val[cur]) {
cur=ch[k>val[cur]][cur];
// printf("cur=%lld k=%lld ch=%lld\n",cur,k,ch[k>val[cur]][cur]);
}
// printf("Find break 3\n");
splay(cur,0);
// printf("Find break 4\n");
return cur;
}
inline void Ins(ll k) {
if(!rt) {rt=sz=st[top--];val[sz]=k;cnt[sz]++;Pushup(sz);return;}
ll cur=rt,f=0;
while(1) {
if(k==val[cur]) {cnt[cur]++;Pushup(cur);splay(cur,0);break;}
f=cur;cur=ch[k>val[cur]][cur];
if(!cur) {
sz=st[top--];val[sz]=k;cnt[sz]++;fa[sz]=f;ch[k>val[f]][f]=sz;
Pushup(sz);Pushup(f);splay(sz,0);break;
}
}
}
inline ll Rk(ll k) {
Find(k);return k>val[rt]?siz[ch[0][rt]]+cnt[rt]+1:siz[ch[0][rt]]+1;
}
inline ll Pre(ll k) {
Find(k);if(k>val[rt]) return rt;ll cur=ch[0][rt];
while(ch[1][cur]) cur=ch[1][cur];splay(cur,0);return cur;
}
inline ll Nxt(ll k) {
Find(k);if(k<val[rt]) return rt;ll cur=ch[1][rt];
while(ch[0][cur]) cur=ch[0][cur];splay(cur,0);return cur;
}
inline void Del(ll k) {
// printf("break 1\n");
Find(k);ll cur=rt;
// printf("break 2\n");
if(cnt[rt]>1) {cnt[rt]--;Pushup(rt);return;}
// printf("break 3\n");
if(!ch[0][rt]&&!ch[1][rt]) {Clear(rt);rt=0;return;}
// printf("break 4\n");
if(!ch[0][rt]) {rt=ch[1][rt];fa[rt]=0;Clear(cur);return;}
// printf("break 5\n");
if(!ch[1][rt]) {rt=ch[0][rt];fa[rt]=0;Clear(cur);return;}
// printf("break 6\n");
ll x=Pre(k);fa[ch[1][cur]]=x;ch[1][x]=ch[1][cur];Clear(cur);Pushup(x);
// printf("break 7\n");
}
}s[N*4+5];

inline void Build(ll p,ll l,ll r) {
s[p].Init(r-l+1);for(ll i=l;i<=r;i++) {s[p].Ins(a[i]);}
if(l==r) {return;}ll mid=(l+r)>>1;
Build(p<<1,l,mid);Build(p<<1|1,mid+1,r);
}

inline ll Askrk(ll p,ll lp,ll rp,ll l,ll r,ll k) {
if(lp>=l&&rp<=r) {return s[p].Rk(k)-1;}
ll mid=(lp+rp)>>1;
if(l>mid) return Askrk(p<<1|1,mid+1,rp,l,r,k);
if(r<=mid) return Askrk(p<<1,lp,mid,l,r,k);
return Askrk(p<<1,lp,mid,l,r,k)+Askrk(p<<1|1,mid+1,rp,l,r,k);
}

inline ll Askkth(ll l,ll r,ll k) {
ll l_=0,r_=1e8,res=0;
while(l_<=r_) {
ll mid=(l_+r_)>>1;
ll tmp=Askrk(1,1,n,l,r,mid)+1;
if(tmp>k) r_=mid-1;else {l_=mid+1;res=mid;}
}
return res;
}

inline void Modify(ll p,ll lp,ll rp,ll pos,ll k) {
// printf("bef?\n");
s[p].Del(a[pos]);
// printf("Mid?\n");
s[p].Ins(k);
// printf("Aft?\n");
if(lp==rp) {return;}ll mid=(lp+rp)>>1;
if(pos<=mid) Modify(p<<1,lp,mid,pos,k);
if(pos>mid) Modify(p<<1|1,mid+1,rp,pos,k);
}

inline ll Askpre(ll p,ll lp,ll rp,ll l,ll r,ll k) {
if(lp>=l&&rp<=r) {
ll tmp=s[p].Pre(k);if(tmp==0) return -inf;return s[p].val[tmp];
}
ll mid=(lp+rp)>>1,res=-inf;
if(l<=mid) res=max(res,Askpre(p<<1,lp,mid,l,r,k));
if(r>mid) res=max(res,Askpre(p<<1|1,mid+1,rp,l,r,k));
return res;
}

inline ll Asknxt(ll p,ll lp,ll rp,ll l,ll r,ll k) {
if(lp>=l&&rp<=r) {
// printf("s[%lld].rt=%lld\n",p,s[p].rt);
// printf("lp=%lld rp=%lld\n",lp,rp);
ll tmp=s[p].Nxt(k);if(tmp==0) return inf;return s[p].val[tmp];
}
ll mid=(lp+rp)>>1,res=inf;
if(l<=mid) res=min(res,Asknxt(p<<1,lp,mid,l,r,k));
if(r>mid) res=min(res,Asknxt(p<<1|1,mid+1,rp,l,r,k));
return res;
}

int main() {
// freopen("input3.in","r",stdin);
// freopen("w.out","w",stdout);

n=read();m=read();
for(ll i=1;i<=n;i++) {a[i]=read();}
Build(1,1,n);

while(m--) {
// printf("Here?\n");
// printf("s[5].rt=%lld\n",s[5].rt);
ll op,x,y,z;op=read();x=read();y=read();
// printf("op=%lld x=%lld y=%lld\n",op,x,y);
if(op==1) {z=read();writeln(Askrk(1,1,n,x,y,z)+1);}
if(op==2) {z=read();writeln(Askkth(x,y,z));}
if(op==3) {Modify(1,1,n,x,y);a[x]=y;}
if(op==4) {z=read();writeln(Askpre(1,1,n,x,y,z));}
if(op==5) {z=read();writeln(Asknxt(1,1,n,x,y,z));}
}

// fclose(stdin);
// fclose(stdout);
return 0;
}