P6478

[NOI Online #2 提高组] 游戏

考虑钦定法与二项式反演。

令 $f(u,k)$ 表示在 $u$ 子树内钦定 $k$ 次非平局。

先不考虑根与子树内部点产生的贡献,那么树形 DP 就是加法卷积的叠加:

$$f(u,k)\leftarrow \sum_{i=0}^{k}f(u,i)f(v,k-i)$$

其中 $v$ 是 $u$ 的儿子。

这里看样子是 $O(n^3)$ 的,把转移换成 NTT 可以做到 $O(n^2\log n)$,但仍然跑不过去。

实际上有一个关于这种树上加法卷积背包的复杂度结论,只要贴着线做普通背包复杂度就是 $O(n^2)$。

$$
\begin{aligned}
T(n)=& O\left( \sum _ {u} \sum _ {v = son(u,i)} \left( \sum _ { j=1} ^ {i-1} siz(son(u,j)) \right) siz(v) \right)
\end{aligned}
$$

每个点只会和其他点构成 $O(n)$ 的复杂度贡献(每次子树大小相乘的时候两颗子树不会有重合的部分,而且对于每个点来说,每次乘的都不一样)。

于是复杂度就是 $O(n^2)$ 的了。

冒死使用 NTT,因为常数过大导致速度不如普通背包。。。

然后是计算根与子树内部点的贡献。

设 $s_0(u)$ 表示 $u$ 子树内小 A 拥有的点数,$s_1(u)$ 表示小 B 拥有的点数。不妨设 $u$ 为小 B 拥有的点:

$$f(u,k+1)\leftarrow f(u,k+1)+f(u,k)(s_0(u)-k)$$

显然也可以 $O(n^2)$ 完成计算。

接下来定义 $\omega(i)=f(1,i)(m-i)!$,$\gamma(i)$ 表示整棵树中恰好有 $i$ 次非平局的方案数。

由组合意义可知:

$$\omega(k)=\sum_{i=k}^{m}\binom{i}{k}\gamma(i)$$

再根据二项式反演:

$$\gamma(k)=\sum_{i=k}^{m}(-1)^{i-k}\binom{i}{k}\omega(i)$$

这里也是可以 $O(n^2)$ 计算的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define Clr(f,n) memset(f,0,sizeof(ll)*(n))
#define Cpy(f,g,n) memcpy(f,g,sizeof(ll)*(n))
using namespace std;
namespace Ehnaev{
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<48||ch>57) {if(ch==45) f=-f;ch=getchar();}
while(ch>=48&&ch<=57) {ret=(ret<<3)+(ret<<1)+ch-48;ch=getchar();}
return ret*f;
}
inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {do{buf[++len]=x%10+48;x/=10;}while(x);}
else {putchar(45);do{buf[++len]=-(x%10)+48;x/=10;}while(x);}
while(len>=0) putchar(buf[len--]);
}
}using Ehnaev::read;using Ehnaev::write;
inline void writeln(ll x) {write(x);putchar(10);}

const ll N=1e4,M=5e3,mo=998244353,G=3;

inline ll Pow(ll b,ll p) {
ll r=1;while(p) {if(p&1) r=r*b%mo;b=b*b%mo;p>>=1;}return r;
}

const ll invG=Pow(G,mo-2);

ll n,m,tot,u,v;
ll rev[N+5],h[N+5],fac[N+5],invfac[N+5],g[N+5];
ll f[M+5][N+5];
ll ver[N+5],nxt[N+5],head[N+5];
ll s0[N+5],s1[N+5],len[N+5];
char s[N+5];

inline void NTT_Init(ll n) {
for(ll i=0;i<n;i++) rev[i]=0;
for(ll i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(n>>1):0);
}

inline void NTT(ll *f,ll n,bool op) {
NTT_Init(n);
for(ll i=0;i<n;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(ll p=2;p<=n;p<<=1) {
ll len=p>>1,tG=Pow(op?invG:G,(mo-1)/p);
for(ll k=0;k<n;k+=p) {
for(ll l=k,buf=1;l<k+len;l++,buf=buf*tG%mo) {
ll t=buf*f[l+len]%mo;
f[l+len]=f[l]-t;if(f[l+len]<0) f[l+len]+=mo;
f[l]=f[l]+t;if(f[l]>=mo) f[l]-=mo;
}
}
}
if(op) {
ll invn=Pow(n,mo-2);for(ll i=0;i<n;i++) f[i]=f[i]*invn%mo;
}
}

inline void Times(ll *f,ll *g,ll &lenf,ll lena,ll lenb) {
static ll sav[N+5];
lenf=1;for(;lenf<=lena+lenb;lenf<<=1);
NTT(f,lenf,0);NTT(g,lenf,0);
for(ll i=0;i<lenf;i++) f[i]=f[i]*g[i]%mo;
NTT(f,lenf,1);Clr(sav,lenf);lenf=lena+lenb;
}

inline void DP(ll p,ll fath) {
s0[p]=(s[p]=='0');s1[p]=(s[p]=='1');
f[p][0]=1;len[p]=0;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
DP(ver[i],p);s0[p]+=s0[ver[i]];s1[p]+=s1[ver[i]];
Times(f[p],f[ver[i]],len[p],len[p],len[ver[i]]);
}
if(s[p]=='0') {
for(ll i=s1[p];i>=1;i--) {
f[p][i]=(f[p][i]+f[p][i-1]*(s1[p]-(i-1))%mo)%mo;
if(f[p][i]>0) len[p]=max(len[p],i);
}
}
else {
for(ll i=s0[p];i>=1;i--) {
f[p][i]=(f[p][i]+f[p][i-1]*(s0[p]-(i-1))%mo)%mo;
if(f[p][i]>0) len[p]=max(len[p],i);
}
}
}

inline ll C(ll n,ll m) {return (fac[n]*invfac[m]%mo)*invfac[n-m]%mo;}

inline ll Calc(ll k) {
ll r=0;
for(ll i=k;i<=m;i++) {
if((i-k)&1) {r=(r-C(i,k)*h[i]%mo+mo)%mo;}
else {r=(r+C(i,k)*h[i]%mo)%mo;}
}
return r;
}

inline void Addedge(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline void Init() {
fac[0]=1;for(ll i=1;i<=m;i++) fac[i]=fac[i-1]*i%mo;
invfac[0]=1;invfac[m]=Pow(fac[m],mo-2);
for(ll i=m-1;i>=1;i--) invfac[i]=invfac[i+1]*(i+1)%mo;
for(ll i=0;i<=m;i++) h[i]=h[i]*fac[m-i]%mo;
}

int main() {

n=read();m=n>>1;scanf("%s",s+1);
for(ll i=1;i<n;i++) {
u=read();v=read();Addedge(u,v);Addedge(v,u);
}
DP(1,0);Cpy(h,f[1],n);Init();

for(ll i=0;i<=m;i++) {writeln(Calc(i));}

return 0;
}