CF997C

Sky Full of Stars

定义 f(a,b) 表示钦定 ab 列颜色相同的方案数。

我们发现 a,b 是否非零对答案是有影响的,不妨分开讨论。

  1. a,b0,则:

    f(a,b)=3(na)(nb)3(na)(nb)

    因为行与列有交点,导致被钦定的行和列颜色必定相同,所以先确定这个颜色,再取确定哪 a 行哪 b 列,然后再随意确定剩余格子的涂色。

  2. ab=0,a+b0,此时我们只有行的颜色相同,或只有列的颜色相同,不妨以行的颜色相同考虑,即 a0,则:

    f(a,0)=(na)3a3n(na)

    对于 b0 的情况是同理的。

  3. a=b=0,则我们所有的格子都是任意涂色的:

    f(0,0)=3n2

接下来我们令 g(a,b) 表示恰好有 ab 列颜色相同的方案数,根据组合意义:

f(a,b)=i=anj=bn(ia)(jb)g(i,j)

然后我们考虑二项式反演:

g(a,b)=i=anj=bn(1)ia(ia)(1)jb(jb)f(i,j)

介于我们最后的答案是:

Ans=3n2g(0,0)

只要计算 g(0,0) 就好了:

g(0,0)=i=0nj=0n(1)i+jf(i,j)

我们有了 O(n2) 的计算方法,但显然还需要优化。


考虑化简式子。

因为我们的 f 是分类讨论得到的,为了方便化简,仍然分类讨论:

  1. a,b0,则:

    i=1nj=1n(1)i+j3(ni)(nj)3(ni)(nj)=3n2+1i=1n(ni)(1)i3inj=1n(nj)(1)j3jn3ij=3n2+1i=1n(ni)(1)i3inj=1n(nj)(1)j3j(in)=3n2+1i=1n(ni)(1)i3inj=1n(nj)1nj(3in)j=3n2+1i=1n(ni)(1)i3in((13in)n1)

    于是这一部分快速幂可以 O(nlogn) 计算了。

  2. ab=0,a+b0,因为两维的情况相同,我们只要算出某一维的贡献然后翻倍即可。

    2i=1n(1)i(ni)3i3n(ni)

    快速幂暴力上就完了,时间复杂度 O(nlogn)

  3. a=b=0,这一部分贡献就是 3n2

    介于我们的答案就是 3n2g(0,0),这一部分就被抵消了。

    所以最后我们的答案就是上面算出的两部分之和的相反数。

最后总的时间复杂度是 O(nlogn) 的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
namespace Ehnaev{
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<48||ch>57) {if(ch==45) f=-f;ch=getchar();}
while(ch>=48&&ch<=57) {ret=(ret<<3)+(ret<<1)+ch-48;ch=getchar();}
return ret*f;
}
inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {do{buf[++len]=x%10+48;x/=10;}while(x);}
else {putchar(45);do{buf[++len]=-(x%10)+48;x/=10;}while(x);}
while(len>=0) putchar(buf[len--]);
}
}using Ehnaev::read;using Ehnaev::write;

const ll N=2e6,mo=998244353;

ll n;
ll pb3[N+5],pp3[N+5],fac[N+5],invfac[N+5];

inline ll Pow_(ll b,ll p) {
ll r=1;while(p) {if(p&1) r=r*b%mo;b=b*b%mo;p>>=1;}return r;
}

inline void Init() {
pb3[0]=pp3[0]=1;
for(ll i=1;i<=N;i++) pp3[i]=pp3[i-1]*3%mo;
for(ll i=1;i<=N;i++) pb3[i]=pb3[i-1]*pp3[N]%mo;
fac[0]=1;for(ll i=1;i<=n;i++) fac[i]=fac[i-1]*i%mo;
invfac[0]=1;invfac[n]=Pow_(fac[n],mo-2);
for(ll i=n-1;i>=1;i--) invfac[i]=invfac[i+1]*(i+1)%mo;
}

inline ll C(ll n,ll m) {return (fac[n]*invfac[m]%mo)*invfac[n-m]%mo;}
inline ll Pow(ll p) {
if(p<0) p=p+(-p/(mo-1)+1)*(mo-1);
if(p>=mo-1) p=p%(mo-1);
return pb3[p/N]*pp3[p%N]%mo;
}

int main() {

n=read();Init();

ll tmp1=0;
for(ll i=1;i<=n;i++) {
ll tmp=C(n,i);tmp=tmp*Pow(-i*n)%mo;
ll tmpp=(mo-Pow(i-n))%mo;tmpp=(tmpp+1)%mo;
tmpp=Pow_(tmpp,n)%mo;tmpp=(tmpp-1+mo)%mo;
tmp=tmp*tmpp%mo;
if(i&1) {tmp1=(tmp1-tmp+mo)%mo;}
else {tmp1=(tmp1+tmp)%mo;}
}
tmp1=tmp1*Pow(n*n+1)%mo;

ll tmp2=0;
for(ll i=1;i<=n;i++) {
ll tmp=C(n,i);tmp=tmp*Pow(i+n*(n-i))%mo;
if(i&1) {tmp2=(tmp2-tmp+mo)%mo;}
else {tmp2=(tmp2+tmp)%mo;}
}
tmp2=tmp2*2%mo;

ll ans=0;ans=(ans-tmp2+mo)%mo;
ans=(ans-tmp1+mo)%mo;

write(ans);

return 0;
}