P3379

【模板】最近公共祖先(LCA)

继续划水摸鱼。

总结一下 LCA 的常用求法:

  1. 倍增。$O(n\log n)$ 预处理,$O(\log n)$ 查询。码量适中。

  2. 树剖。$O(n)$ 预处理,$O(\log n)$ 查询。码量小。

  3. Tarjan。离线算法,需要预先知道询问。$O(n+q)$ 预处理,$O(1)$ 查询。码量小。

  4. 四毛子。$O(n\log n)$ 预处理,$O(1)$ 查询。码量稍大。

首先是第一种。原理是存储 $x$ 的 $2^i$ 级父亲 $fa(x,i)$,然后往上跳。优点是比较好写,速度还好,可以在线。缺点是还要再多开一些内存。

第二种树剖是我比较喜欢的写法。树剖实际上非常灵活,不少树上问题都可以用树剖来解决。实际上原理和第一种很类似,我们是跳重链来保证复杂度。这个算法的空间也是 $O(n)$ 的,预处理也不会有瓶颈,码量非常小,记忆很方便。

第三种的做法依托于 DFS,借助并查集来寻找 LCA。优点是比较好写,而且重复访问复杂度 $O(1)$。缺点是只能离线。

第四种做法依托于欧拉序和 ST 表。原理就是把树的欧拉序搞出来,每个位置存放其深度,查询两点间欧拉序中深度最小的点就是 LCA。优点是查询快,可以在线。缺点是码量稍大,处理起来有一点点麻烦。

这里没有第四种的代码,因为我比较懒

代码(倍增):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=5e5;

ll n,m,s,u,v,tot;

ll lg[N+5],ver[(N<<1)+5],nxt[(N<<1)+5],head[N+5],dt[N+5],fa[N+5][25];

void dfs(ll p,ll fath) {
fa[p][0]=fath;dt[p]=dt[fath]+1;
for(ll i=1;i<=lg[dt[p]];i++) {
fa[p][i]=fa[fa[p][i-1]][i-1];
}
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]!=fath) dfs(ver[i],p);
}
}

void init() {
for(ll i=1;i<=n;i++) {
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
dfs(s,0);
}

ll lca(ll a,ll b) {
if(dt[a]<dt[b]) swap(a,b);
while(dt[a]>dt[b]) a=fa[a][lg[dt[a]-dt[b]]-1];
if(a==b) return a;
for(ll k=lg[dt[a]]-1;k>=0;k--) {
if(fa[a][k]!=fa[b][k]) {
a=fa[a][k];b=fa[b][k];
}
}
return fa[a][0];
}

void add(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}

void write(ll x) {
if(x<0) {x=-x;putchar('-');}
if(x>9) write(x/10);
putchar(x%10+48);
}

int main() {

n=read();m=read();s=read();

for(ll i=1;i<n;i++) {
u=read();v=read();add(u,v);add(v,u);
}

init();

for(ll i=1;i<=m;i++) {
u=read();v=read();write(lca(u,v));putchar('\n');
}

return 0;
}

代码(树剖):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;

const ll N=5e5;

ll n,m,rt,u,v,tot;

ll ver[N*2+5],nxt[N*2+5],head[N+5];

ll siz[N+5],dt[N+5],fa[N+5],hs[N+5],top[N+5];

inline void dfs(ll p,ll fath) {
siz[p]=1;dt[p]=dt[fath]+1;fa[p]=fath;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dfs(ver[i],p);
if(siz[ver[i]]>siz[hs[p]]) hs[p]=ver[i];
siz[p]+=siz[ver[i]];
}
}

inline void dfs_(ll p,ll fath,ll topf) {
top[p]=topf;
if(hs[p]) {
dfs_(hs[p],p,topf);
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath||ver[i]==hs[p]) continue;
dfs_(ver[i],p,ver[i]);
}
}
}

inline ll getlca(ll x,ll y) {
while(top[x]!=top[y]) {
if(dt[top[x]]<dt[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dt[x]<dt[y]) return x;
return y;
}

inline void addedge(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}

inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=-(x%10)+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}

inline void writeln(ll x) {
write(x);putchar('\n');
}

int main() {

n=read();m=read();rt=read();

for(ll i=1;i<n;i++) {
u=read();v=read();
addedge(u,v);addedge(v,u);
}

dfs(rt,0);dfs_(rt,0,rt);

while(m--) {
u=read();v=read();
writeln(getlca(u,v));
}

return 0;
}

代码(Tarjan):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;

const ll N=5e5;

ll n,m,rt,u,v,tot,tq;

ll fa[N+5],ans[N+5],a[N+5],b[N+5];

ll ver[N*2+5],nxt[N*2+5],head[N+5];

ll vq[N*2+5],nq[N*2+5],hq[N*2+5];

bool vis[N+5];

inline ll find(ll x) {
if(x==fa[x]) return x;
return fa[x]=find(fa[x]);
}

inline void uni(ll x,ll y) {
fa[find(x)]=find(y);
}

inline void dfs(ll p,ll fath) {
vis[p]=1;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dfs(ver[i],p);uni(ver[i],p);
}
for(ll i=hq[p];i;i=nq[i]) {
if(vis[vq[i]]) {ans[i/2]=find(vq[i]);}
}
}

inline void addedge(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline void addquery(ll u,ll v) {
vq[++tq]=v;nq[tq]=hq[u];hq[u]=tq;
}

inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}

inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=-(x%10)+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}

inline void writeln(ll x) {
write(x);putchar('\n');
}

int main() {

n=read();m=read();rt=read();

for(ll i=1;i<n;i++) {
u=read();v=read();
addedge(u,v);addedge(v,u);
}

for(ll i=1;i<=n;i++) fa[i]=i;

tq=1;
for(ll i=1;i<=m;i++) {
a[i]=read();b[i]=read();
addquery(a[i],b[i]);addquery(b[i],a[i]);
}

dfs(rt,0);

for(ll i=1;i<=m;i++) {
writeln(ans[i]);
}

return 0;
}