Diameter Cuts
把这个当作长链剖分的基础题讲了。
首先要有 DP。(下文所著 $K$ 为题目所述 $k$,下文的 $k$ 是普通自变量)
我们对于一个点,可以分类讨论:
- 不连上这个点与其儿子的边,则有转移:
$$f(p,j)=f(p,j)\sum_{k=0}^K f(v,k)$$
- 连上这个点与其儿子的边,则有转移:
$$f(p,\max{j,k+1})=f(p,\max{j,k+1})+f(p,j)\sum_{k=0}^{K-j-1}f(v,k)$$
上述的转移显然是独立出来的,而实际情况需要考虑其间的相互影响,所以用一个临时的变量 $g$ 存贮结果,在最后再传给 $f$。
那么问题来了,这个转移的复杂度是 $O(nk^2)$ 的,怎么使用长链剖分优化?
实际上,就是在满足题目计数的限制的同时,再满足长剖的计数限制。
这样复杂度就可以优化到 $O(nk)$ 了,空间复杂度是 $O(n)$ 的。
有一些细节,具体看代码。
在做完这题之后,我们对长链剖分优化的本质有了更为深入的了解。
实际上,长剖所作的就是让一条长链上的点重复利用同一段状态计数来节省空间,再让每次计数仅发生在长链顶端来节省时间,从而达到了优化的效果。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
| #include<iostream> #include<cstdio> #define ll long long using namespace std;
const ll N=5e3,mo=998244353;
ll n,K,u,v,tot,ans;
ll buf[N*10+5];
ll *f[N+5],*now=buf;
ll g[N+5],dep[N+5],ls[N+5];
ll ver[N*2+5],nxt[N*2+5],head[N+5];
inline void dfs(ll p,ll fath) { for(ll i=head[p];i;i=nxt[i]) { if(ver[i]==fath) continue; dfs(ver[i],p); if(dep[ver[i]]>dep[ls[p]]) ls[p]=ver[i]; } dep[p]=dep[ls[p]]+1; }
inline void dfs_(ll p,ll fath) { if(ls[p]) { f[ls[p]]=f[p]+1; dfs_(ls[p],p); for(ll j=0;j<=K&&j<dep[ls[p]];j++) { f[p][0]+=f[ls[p]][j]; if(f[p][0]>mo) f[p][0]%=mo; } } else f[p][0]=1; for(ll i=head[p];i;i=nxt[i]) { if(ver[i]==fath||ver[i]==ls[p]) continue; f[ver[i]]=now;now+=dep[ver[i]]+3; dfs_(ver[i],p); for(ll j=0;j<=K&&j<dep[p];j++) g[j]=0; for(ll j=0;j<=K&&j<dep[p];j++) { for(ll k=0;k<dep[ver[i]]&&j+k+1<=K;k++) { g[max(j,k+1)]+=f[p][j]*f[ver[i]][k]; if(g[max(j,k+1)]>mo) g[max(j,k+1)]%=mo; } } for(ll j=0;j<=K&&j<dep[p];j++) { for(ll k=0;k<=K&&k<dep[ver[i]];k++) { g[j]+=f[p][j]*f[ver[i]][k]; if(g[j]>mo) g[j]%=mo; } } for(ll j=0;j<=K&&j<dep[p];j++) f[p][j]=g[j]; } }
inline void addedge(ll u,ll v) { ver[++tot]=v;nxt[tot]=head[u];head[u]=tot; }
inline ll read() { ll ret=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();} while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();} return ret*f; }
inline void write(ll x) { static char buf[22];static ll len=-1; if(x>=0) { do{buf[++len]=x%10+48;x/=10;}while(x); } else { putchar('-'); do{buf[++len]=-(x%10)+48;x/=10;}while(x); } while(len>=0) putchar(buf[len--]); }
int main() {
n=read();K=read();
for(ll i=1;i<n;i++) { u=read();v=read(); addedge(u,v);addedge(v,u); }
dfs(1,0); f[1]=now;now+=dep[1]+3; dfs_(1,0);
for(ll i=0;i<=K&&i<dep[1];i++) { ans+=f[1][i];if(ans>mo) ans%=mo; } write(ans);
return 0; }
|