P4781

【模板】拉格朗日插值

实际上拉格朗日插值是一种构造,所以我们先从构造的第一步开始。

我们构造 $f_i(x)$ 使得:

$$\begin{cases}f_i(x_i)=y_i \\ f_i(x_j)=0 & j \not=i\end{cases}$$

接下来就是一个很 whk 数学的构造(交点式),设:

$$f_i(x)=a\prod_{j\not=i}(x-x_j)$$

这样就满足了 $\forall j\not=i$,$f_i(x_j)=0$。

接下来我们代入 $x=x_i$,就可以得到:

$$y_i=a\prod_{j\not=i}(x_i-x_j)$$

得到了 $a$ 的表达式:

$$a=\dfrac{y_i}{\prod_{j\not=i}(x_i-x_j)}$$

再把这个表达式往一开始的交点式里代:

$$f_i(x)=y_i\prod_{j\not=i}\dfrac{x-x_j}{x_i-x_j}$$

至于为什么要这么构造:

$$f(x)=\sum_{i=1}^nf_i(x)$$

所以就是:

$$f(x)=\sum_{i=1}^ny_i\prod_{j\not=i}\dfrac{x-x_j}{x_i-x_j}$$

我们就得到了拉格朗日插值的式子。

这个题就不用怎么变形了,直接 $O(n^2\log mo)$ 就能过了。

关于 CRT 的推导,我们所有的同余关系都是建立在多项式上的。

需要摒除传统的数字观念来看多项式的同余。

下面基本是复读 OI wiki 的证明。

因为:

$$f(x)-f(a)=(f_0-f_0)+f_1(x-a)+f_2(x^2-a^2)+\cdots+f_n(x^n-a^n)$$

所以 $f(x)-f(a)$ 必然有因式 $x-a$(可以对这个东西做多项式除法)。

然后我们就能得到:

$$f(x)-f(a)\equiv 0\pmod{x-a}$$

其实就是:

$$f(x)\equiv f(a)\pmod{x-a}$$

然后我们 $a$ 取遍 $x_i$ 组成线性方程组:

$$\begin{cases}f(x)\equiv f(x_1)\equiv y_1\pmod{(x-x_1)} \\ f(x)\equiv f(x_2)\equiv y_2 \pmod{(x-x_2)} \\ \cdots \\ f(x)\equiv f(x_n)\equiv y_n\pmod{(x-x_n)}\end{cases}$$

你会发现 $x-x_i$ 在多项式的意义下是两两互质的。

所以可以在多项式的意义下使用中国剩余定理。

于是我们就有了:

$$m=\prod_{i=1}^n(x-x_i)$$

$$M_i=\dfrac{m}{x-x_i}=\prod_{j\not=i}(x-x_j)$$

$$t_i=M_i^{-1}=\prod_{j\not=i}\dfrac{1}{x-x_j}$$

显然这里的 $t_i$ 形式不太好,我们根据它是在 $\pmod{(x-x_i)}$ 下 $M_i$ 的逆元来对它进行变换。

模意义下 $x-x_j$ 与 $x-x_j-(x-x_i)$ 是相同的,所以直接换掉:

$$t_i=\prod_{j\not=i}\dfrac{1}{x_i-x_j}$$

于是我们的解就是:

$$f(x)=\sum_{i=1}^ny_iM_it_i=\sum_{i=1}^n y_i\prod_{j\not=i}\dfrac{x-x_j}{x_i-x_j}$$

这个推导是多项式模意义下的。但反正也是适用的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=2e3,mo=998244353;

ll n,k,ans;

ll x[N+5],y[N+5];

inline ll qpow(ll b,ll p) {
ll res=1;while(p){if(p&1) res=res*b%mo;b=b*b%mo;p>>=1;}return res;
}
inline ll inv(ll x) {return qpow(x,mo-2);}

inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<48||ch>57) {if(ch==45) f=-f;ch=getchar();}
while(ch>=48&&ch<=57) {ret=(ret<<3)+(ret<<1)+ch-48;ch=getchar();}
return ret*f;
}
inline void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {do{buf[++len]=x%10+48;x/=10;}while(x);}
else {putchar(45);do{buf[++len]=-(x%10)+48;x/=10;}while(x);}
while(len>=0) putchar(buf[len--]);
}

int main() {

n=read();k=read();

for(ll i=1;i<=n;i++) {x[i]=read();y[i]=read();}

for(ll i=1;i<=n;i++) {
ll tmp=y[i];
for(ll j=1;j<=n;j++) {
if(i==j) continue;
tmp=tmp*((k-x[j]+mo)%mo)%mo;
tmp=tmp*inv((x[i]-x[j]+mo)%mo)%mo;
}
ans=(ans+tmp)%mo;
}

write(ans);

return 0;
}