CF600E

Lomsat gelral

其实挺好理解的。理解了树剖的复杂度原理之后,这个东西的原理其实非常相似。

同样,我们第一次 dfs 寻找重儿子。

然后我们开始统计。

统计的顺序是这样的:

  1. 求解 $p$ 点轻儿子子树内的答案。但是不保留它对桶 $cnt$ 的贡献。

  2. 求解 $p$ 的重儿子子树内的答案,并且保留(暂时)它对桶 $cnt$ 的贡献。

  3. 再遍历 $p$ 的轻儿子子树,相当于暴力统计出 $p$ 的答案,然后 $p$ 的答案就算出来了。

  4. 如果需要保留现有的桶和答案什么的就保留;反之,我们就把之前统计的所有桶内的答案撤销。

这个原理我们可以从每一个点的被数次数想。

显然一个点被统计的次数与其到根节点路径上的轻边数基本线性相关,或者,更通俗地说,与重链数基本线性相关。而根据重链剖分的性质我们容易知道重链数是 $O(\log n)$ 的。

显然求解的递归过程是 $O(n)$(除去计数操作)。

那么意味着每个点都会被基础性的操作一次。

那么除去基础性的操作,剩下的都是附加操作了。

一个点被附加操作数到,当且仅当被作为轻儿子子树内的一个点。

于是乎就和轻边数线性相关了。

所以时间复杂度是 $O(n\log n)$ 的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=1e5;

ll n,u,v,tmp_col,tot,macnt,mc,tc;

ll ver[N*2+5],nxt[N*2+5],head[N+5];

ll main_col[N+5],siz[N+5],hs[N+5],cnt[N+5],c[N+5];

inline void dfs(ll p,ll fath) {
siz[p]=1;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dfs(ver[i],p);
if(siz[ver[i]]>siz[hs[p]]) hs[p]=ver[i];
siz[p]+=siz[ver[i]];
}
}

inline void add(ll p,ll fath,ll k) {
cnt[c[p]]+=k;
if(cnt[c[p]]>mc) {mc=cnt[c[p]];tc=c[p];}
else {if(cnt[c[p]]==mc) tc+=c[p];}
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
add(ver[i],p,k);
}
}

inline void dfs_(ll p,ll fath,bool kp) {
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath||ver[i]==hs[p]) continue;
dfs_(ver[i],p,0);
}
if(hs[p]) dfs_(hs[p],p,1);
mc=macnt;tc=tmp_col;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath||ver[i]==hs[p]) continue;
add(ver[i],p,1);
}
cnt[c[p]]++;
if(cnt[c[p]]>mc) {mc=cnt[c[p]];tc=c[p];}
else {if(cnt[c[p]]==mc) tc+=c[p];}
main_col[p]=tc;
if(kp) {macnt=mc;tmp_col=tc;}
if(!kp) {
cnt[c[p]]--;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
add(ver[i],p,-1);
}
macnt=0;tmp_col=0;
}
}

inline void addedge(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}

inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=-(x%10)+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}

int main() {

n=read();

for(ll i=1;i<=n;i++) {
c[i]=read();
}

for(ll i=1;i<n;i++) {
u=read();v=read();
addedge(u,v);addedge(v,u);
}

dfs(1,0);

dfs_(1,0,0);

for(ll i=1;i<=n;i++) {
write(main_col[i]);putchar(' ');
}

return 0;
}