ei1333の日記

ぺこい

ラグランジュ補間

ラグランジュ補間ってなーんだ?

算数エアプが書きます 多分算数エアプでもわかる 誤りには気をつけているつもりですが、エアプ特有のミスがあったらすみません

最初に

 {p}素数とします。 基本的には  {\bmod p} 上での演算を対象とした議論です。加減乗算を  {O(1)}、除算を  {O(\log p)} で行えることを仮定しています。

添字は 0-indexed です。

ラグランジュ補間多項式

 {(x_i, y_i)} の組が  {k + 1} 個あります。ここで  {x_i} は相異なります。

このラグランジュ補間多項式  {f}は、すべての  {i (0 \leq i \leq k)} に対して  {f(x_i) = y_i} を満たす、次数が最小の多項式を指します。

証明は知らんのでしませんが、多項式の性質として、次数が最小の多項式は与えられた組に対して一意に定まります。 つまりラグランジュ補間多項式は一意です。(係数が体なら、対応する次数  {k} 以下の多項式が一意に存在するため、次数  {k} 以下の最小次数の多項式も一意)

ラグランジュ補間多項式の定義

ラグランジュ基底多項式  {l_i(x)} を (1) に定めます。

 {\displaystyle l_i(x) = \prod_{ j \neq i } \frac {x - x_j} {x_i - x_j} =  \frac {\prod_{ j \neq i } (x - x_j)} {\prod _{ j \neq i } (x_i - x_j)} \tag{1}}

 {i = j} ならば  {l_i(x_j) = 1} {i \neq j} ならば  {l_i(x_j) = 0} を返す多項式であることを確認できます。( {i=j} なら分子分母同じなので  {1} {i \neq j} ならば分子の積のうちどこかが  {0} になるので  {0} です。)

このラグランジュ基底多項式を用いて、ラグランジュ補間多項式  {f(x)} を (2) で表せます。

 {\displaystyle f(x) = \sum_{i} y_i l_i (x) \tag{2}}

 {i=j} のときに  {f(x_j) = y_i} となることは  {l_i (x)} の性質から明らかです。

 {l_i(x)} は高々  {k} 次の多項式なので、当然その sum も高々  {k} 次です。したがって、与えられた条件を全て満たす最小次数の多項式を補間できたことになります。

多項式の係数

ラグランジュ補間多項式を復元するためには、上の定義の基づいて  {x^p (0 \leq p \leq k)} の係数を計算していけば良いです。

 {\{f(x_0), f(x_1), \cdots, f(x_k)  \} \Rightarrow f } の各係数

かんたん

 {O(k^3)}アルゴリズムです。上の定義をそのまま実装するとできます。

まず、ラグランジュ基底多項式を求めます。  {k + 1} 本の多項式をそれぞれ求めることとして、いま  {i} 番目の多項式を求めたいとします。この多項式の分母部分を  {w_i}、分子部分を {m_i} とします。それぞれ以下のように計算されます。

 { \displaystyle w_i = \prod_{ j \neq i } (x_i - x_j) \tag{3}}

 { \displaystyle m_i = \prod_{ j \neq i } (x - x_j) \tag{4}}

分子  {w_i} はねねちゃんをすると  {O(k)} で求められます。

分母  {m_i} を求める際は  {0 \leq p \leq k} {p} について  {x^p} の係数を求めます。 各  {j} について  {x} を取るとき次数が  {1} 増えて  {-x_j} を取る時次数が変わらないとして、今までの  {x} の次数を状態に持った  {O(k^2)} のDP により求められます。

具体的には、 {dp(i, j)} {i} 番目まで見たときの  {x^j} の次数とします。遷移は (5) です。

 {dp(i, j) = dp(i - 1, j - 1) + dp(i - 1, j) \times (-x_{i}) \tag{5}}

最後に求めたラグランジュ基底多項式の和を求めます。 足すだけなので  {O(k^2)} です。

template< typename Mint >
vector< Mint > lagrange_polynomial(const vector< Mint > &x, const vector< Mint > &y) {
  int k = (int) x.size() - 1;
  vector< Mint > f(k + 1);
  for(int i = 0; i <= k; i++) {
    Mint d = 1;
    for(int j = 0; j <= k; j++) {
      if(i != j) {
        d *= x[i] - x[j];
      }
    }
    vector< Mint > dp(k + 1);
    dp[0] = y[i] / d;
    for(int j = 0; j <= k; j++) {
      if(i != j) {
        for(int l = k; l > 0; l--) {
          dp[l] = dp[l] * -x[j] + dp[l - 1];
        }
        dp[0] *= -x[j];
      }
    }
    for(int j = 0; j <= k; j++) {
      f[j] += dp[j];
    }
  }
  return f;
}
ふつう

 {O(k^2)}アルゴリズムです。一般的な多項式の復元はこのアルゴリズムを用いれば良いです。

かんたん のボトルネックは、ラグランジュ基底多項式の分子を求める部分でした。これを全体で  {O(k^2)} で求まれば良さそうです。

前計算として  {i} を無視して、 { \displaystyle \prod_{j} (x - x_j) } {x^p} の係数を求める DP をします。(4) を整理すると (6) になります。

 { \displaystyle m_i =  \prod_{ j \neq i } (x - x_j) = \displaystyle \prod_{j} (x - x_j) \times \frac {1} {x - x_i} \tag{6}}

したがって  {(x - x_i)} を取り除いた結果が求まればよいです。

多項式除算をすると log がかかりますが、戻すDPをすればさらに効率的に求めることが可能です。DP で取る順番を変えても計算結果を変わらないことは明らかです。また DP の遷移は (4) のように表されました。

 {dp(i, j) = dp(i - 1, j - 1) + dp(i - 1, j) \times (-x_{i}) \tag{4} }

これを  {dp(i - 1, j)} について解きます。

 {\displaystyle \begin{eqnarray}
dp(i - 1, j) =
  \begin{cases}
    dp(i, j + 1)\ \ ( x_i = 0 ) \\
    \displaystyle \frac {dp(i, j) - dp(i-1,j-1)} {-x_{i}}\ \ ( x_i \neq 0 )
  \end{cases}
\end{eqnarray} \tag{7}}

これを用いることで  {1} 個要素を除いたときの DP を  {O(k)} で復元できて,各ラグランジュ基底多項式 {O(k)} で求められます。

template< typename Mint >
vector< Mint > lagrange_polynomial(const vector< Mint > &x, const vector< Mint > &y) {
  int k = (int) x.size() - 1;

  vector< Mint > f(k + 1), dp(k + 2);
  dp[0] = 1;
  for(int j = 0; j <= k; j++) {
    for(int l = k + 1; l > 0; l--) {
      dp[l] = dp[l] * -x[j] + dp[l - 1];
    }
    dp[0] *= -x[j];
  }

  for(int i = 0; i <= k; i++) {
    Mint d = 1;
    for(int j = 0; j <= k; j++) {
      if(i != j) {
        d *= x[i] - x[j];
      }
    }
    Mint mul = y[i] / d;
    if(x[i] == 0) {
      for(int j = 0; j <= k; j++) {
        f[j] += dp[j + 1] * mul;
      }
    } else {
      Mint inv = Mint(1) / (-x[i]), pre = 0;
      for(int j = 0; j <= k; j++) {
        Mint cur = (dp[j] - pre) * inv;
        f[j] += cur * mul;
        pre = cur;
      }
    }
  }
  return f;
}
むずかしい

 {O(k \log^2 k)}アルゴリズムです。 詳細は他の文献を参照してください。

 { w_i = \displaystyle \prod_{ j \neq i } (x_i - x_j) \tag{3}}

分母  {w_i (0 \leq i \leq k)} の列を効率的に求めるところから始めます。多項式  {l(x)} を (8) のように定義します。

 {\displaystyle l(x) = \prod_{i=0}^{k} (x - x_i) \tag {8} }

 {l(x)} は分割統治で  {O(k \log^2 k)} で求められます。 このとき、がんばると  {w_i = l'(x_i)} が成立することがわかります。 多項式  {l'(x)} について各  {x_i} について評価された値を知りたいので、 multipoint evalution をすればよいです。このアルゴリズムの計算量は  {O(k \log^2 k)} でできることが知られていますが(知られていてくれ)、その解説は他の文献を参照してください(剰余の定理により  {f(x_i)} {f(x) \pmod {x - x_i}} と等しいことが分かるので、多項式  {\prod (x - x_i)} をノードに持たせたセグ木状の subproduct tree を分割統治して作ります。その木を  {f(x) \pmod {\prod (x - x_i) }} を計算しながら潜っていくと、各葉でそれぞれの  {f(x_i)} の値を求めることができます。このときに、多項式のmodがNTTを用いて  {O(n \log n)} で求まることを使うので、形式的冪級数のライブラリがあると良いです。 )

分母の列を  {O(k \log^2 k)} で求められたので、次に分子の列及び  {y_i} を掛けながら、多項式  {f} の係数を求めます。多項式  {f_{l, r}} を (9) のように定義します。

 {\displaystyle f_{l,r} (x) = \sum_{i=l}^{r-1} \frac {y_i} {w_i} \prod_{ l \leq j \lt r, j \neq i } (x - x_j) \tag{9}}

 {l + 1 = r} の場合は、(10) です。

 {\displaystyle f_{l, r} = \frac {y_l} {w_l} \tag{10}}

そうでない場合は、 {f_{l, m} (x), f_{m, r} (x)} から  {f_{l, r} (x)} を求めることを考えます。これは  {f_{l, m} (x)} {\displaystyle \prod_{j=m}^{r-1} (x - x_j) } {f_{m, r} (x)} {\displaystyle \prod_{j=l}^{m-1} (x - x_j) } を掛けた上で足し合わせればよいです。  {\displaystyle m = \frac {l + r} {2}} とすれば分割統治の計算量になって、掛けるところで畳み込みの log がかかるので  {O(k \log^2 k)} です。

実装例は https://judge.yosupo.jp/problem/polynomial_interpolation にあるんじゃないか

多項式の値

ある次数  {k} 以下の多項式  {f(x)} が存在することがわかっています。また、 {k+1} 個の  {(x_i, f(x_i))} の組が与えられています。このとき、多項式の係数を求める問題ではなくて、 {f(T)} ( {T} は大きめ) を求めたい場合、多項式の係数を陽に復元しないほうが  {O(k^2)} よりも効率的なアルゴリズムを導出できる場合があります。

 {\{f(x_0), f(x_1), \cdots, f(x_k)  \} \Rightarrow f(T) }

 {x} が一般の場合です。

 {\displaystyle f(T) = \sum_{i} y_i l_i (T)} となり、  {l_i (T)} を値として求めれば良くて、係数は求めなくて良いです。

 {\displaystyle l_i(T) = \frac {\prod_{ j \neq i } (T - x_j)} {\prod _{ j \neq i } (x_i - x_j) } } なのでこれに基づいて計算すると、逆元を求める計算量を  {O(\log p)} として  {O(k + \log p)} で求められます。

これを  {k + 1} 本求めるので、全体で  {O(k^2 + k \log p)} となります。 この場合は係数を陽に求める場合と計算量は同じですが、定数倍が軽い気がします。 ( {O(k \log^2 k)} のギャグとの兼ね合いは分からん)

template< typename Mint >
Mint lagrange_polynomial(const vector< Mint > &x, const vector< Mint > &y, const int64_t& T) {
  int k = (int) x.size() - 1;
  Mint ret(0);
  for(int i = 0; i <= k; i++) {
    Mint m = 1, d = 1;
    for(int j = 0; j <= k; j++) {
      if(i != j) {
        m *= Mint(T) - x[j];
        d *= x[i] - x[j];
      }
    }
    ret += y[i] * m / d;
  }
  return ret;
}

 {\{f(0), f(1), \cdots, f(k)  \} \Rightarrow f(T) }

ある多項式が存在して、その次数が  {k} 以下だと分かっていて  {k} が小さい場合に、  {f(0), f(1), \cdots, f(k)} まで DP で求めて、 {f(T)} ( {T} が巨大) を多項式補間により求めるみたいな場面がよくあります(よくあってくれ)。 よくあるので、これに対応するライブラリを持っている人も結構多い気がします。

 {x=T} としたときのそれぞれの値  {l_i(T)} を効率的に求めたいです。

 {\displaystyle l_i(T) = \frac {\prod_{ j \neq i } (T - x_j)} {\prod _{ j \neq i } (x_i - x_j) }} を観察します。 今回考える制約から  {x_i = i} なので  {\displaystyle l_i(T) = \frac {\prod_{ j \neq i } (T - j)} {\prod _{ j \neq i } (i - j) }} です。

分子は  {\displaystyle \prod_{ j \neq i } (T - j)} です。前計算として先頭及び末尾からの  {(T - j)} の累積積を計算しておけば  {O(1)} です。(これは  {x_i} が一般でも適用可能)

分母は  {\displaystyle \prod_{ j \neq i } (i - j)} です。 これは  {(-1)^{k-i} i! (k-i)!} に対応します。前計算として階乗テーブルを求めておけば  {O(1)} です。

 {l_i(T)} {O(1)} で求められるため全体で  {O(k + \log p)} です。

template< typename Mint >
Mint lagrange_polynomial(const vector< Mint > &y, const int64_t& T) {
  int k = (int) y.size() - 1;
  if(T <= k) return y[T];
  Mint ret(0);
  vector< Mint > dp(k + 1, 1), pd(k + 1, 1), finv(k + 1, 1);
  for(int i = 0; i < k; i++) dp[i + 1] = dp[i] * (T - i);
  for(int i = k; i > 0; i--) pd[i - 1] = pd[i] * (T - i);
  for(int i = 2; i <= k; i++) finv[k] *= i;
  finv[k] = Mint(1) / finv[k];
  for(int i = k; i >= 1; i--) finv[i - 1] = finv[i] * i;
  for(int i = 0; i <= k; i++) {
    Mint tmp = y[i] * dp[i] * pd[i] * finv[i] * finv[k - i];
    if((k - i) & 1) ret -= tmp;
    else ret += tmp;
  }
  return ret;
}

 {\{ f(0), f(a), f(2a), \cdots, f(ka) \} \Rightarrow f(T) }

 {x} が項数  {k+1} の等差数列になっている場合です(初項が非  {0} ならそれを引けば  {0} の場合に帰着できます)。

 {f(T)} について求めたい場合は、  {T} {a} で割って  {x = \{0, 1, \cdots, k\} } の場合について解けば良いことがわかり、終わりです。差が一定であることを用いて直接導出することも可能です。

多項式の値たち(標本点のシフト)

 {\{f(x_0), f(x_1), \cdots, f(x_k)  \} \Rightarrow \{ f(a_0), f(a_1), \cdots, f(a_{m-1})\} }

陽にラグランジュ補間多項式を求めて、この多項式に対して、 {1} つずつ  {f(a_i)} を求めると  {O(km)}、multipoint evalution をすると  {O((m + k) \log^2 {(m + k)})} で求まる気がします。(これいる?)

 {\{f(0), f(1), \cdots, f(k)  \} \Rightarrow \{ f(T), f(T+1), \cdots, f(T+m-1)\} }

 {k \lt T} を仮定しています。 {k \ge T} の場合は、その部分について既知の  {f(x)} の値を使うことで  {k \lt T} の場合に帰着できます。

どちらかが単調ではない場合、先に示した遅め or 定数倍大爆発 アルゴリズムを使う必要があると思っています。

ここでは両方とも等差  {1} の数列のときに、より効率的に求めるアルゴリズムを解説します。このアルゴリズムは、例えば階乗を  {O(\sqrt p \log p)} で求めるアルゴリズムに用いられていたりします。 

例によって  {\displaystyle l_i(T+x) = \frac {\prod_{ j \neq i } (T+x - j)} {\prod _{ j \neq i } (i - j) }} を観察します。

分母は  {\displaystyle \prod_{ j \neq i } (i - j)} {x} に依存しないので、階乗テーブルを用いて  {O(k)} で求めておきます。 これらの値の逆数を並べた上で  {y_i} を掛けた列を  {\{d_0, d_1, \cdots, d_k\}} とすると  {f(T+x)} の値は (11) により求められます。

 {\displaystyle f(T + x) = \sum_{i=0}^{k} d_i \prod_{ j \neq i } (T+x - j) \tag {11} }

これを整理します。

 {f(T+x)=\displaystyle (\prod_{j=0}^{k} T+x - j ) ( \sum_{i=0}^{k} \frac {d_i} {T+x-i} ) \tag{12} }

 {\displaystyle \prod_{j=0}^{k} (T+x - j) } の値は  {x=0,1,\cdots,m} に対し、全体で  {O(k+m \log p)} で求められます。  {x} の値から  {x+1} の値に移すときの差分を考えると、 {\displaystyle \frac {T+x+1} {T-k+x}} を掛けることで求まることが確認できます。

したがって  {\displaystyle \sum_{i=0}^{k}  \frac {d_i} {T+x-i} } の値を各  {x=0,1,\cdots,m-1} に対して求めることが本質です。

 {\displaystyle g(i) = d_i (0 \leq i \leq k)} {\displaystyle h(i) = \frac {1} {T-k+i} (0 \leq i \lt k+m)} とします。

これを見ると次の関係がわかります。

 {\displaystyle \sum_{i=0}^{k} g(i) h(k+x-i) = \sum_{i=0}^{k} \frac {d_i} {T+x-i} \tag{13}}

これは、求めたい値そのものです。そして、 {g} {h} を畳み込んだあとの  {k+x} 番目の要素に対応していることが分かります。畳み込みは NTT などを用いて  {O((m + k) \log {(m + k)})} で求められます。

したがって、全体でも  {O((m + k) \log {(m + k)})} で求めることが出来ます。

下の実装例では  {T+x-i=0} となる部分で破滅するので、そこの場合分けもしています。

template< typename Mint, typename F >
vector< Mint > lagrange_polynomial(const vector< Mint > &y, int64_t T, const int &m, const F &multiply) {
  int k = (int) y.size() - 1;
  T %= Mint::get_mod();
  if(T <= k) {
    vector< Mint > ret(m);
    int ptr = 0;
    for(int64_t i = T; i <= k and ptr < m; i++) {
      ret[ptr++] = y[i];
    }
    if(k + 1 < T + m) {
      auto suf = lagrange_polynomial(y, k + 1, m - ptr, multiply);
      for(int i = k + 1; i < T + m; i++) {
        ret[ptr++] = suf[i - (k + 1)];
      }
    }
    return ret;
  }
  if(T + m > Mint::get_mod()) {
    auto pref = lagrange_polynomial(y, T, Mint::get_mod()-T, multiply);
    auto suf = lagrange_polynomial(y, 0, m - pref.size(), multiply);
    copy(begin(suf), end(suf), back_inserter(pref));
    return pref;
  }

  vector< Mint > finv(k + 1, 1), d(k + 1);
  for(int i = 2; i <= k; i++) finv[k] *= i;
  finv[k] = Mint(1) / finv[k];
  for(int i = k; i >= 1; i--) finv[i - 1] = finv[i] * i;
  for(int i = 0; i <= k; i++) {
    d[i] = finv[i] * finv[k - i] * y[i];
    if((k - i) & 1) d[i] = -d[i];
  }

  vector< Mint > h(m + k);
  for(int i = 0; i < m + k; i++) {
    h[i] = Mint(1) / (T - k + i);
  }

  auto dh = multiply(d, h);

  vector< Mint > ret(m);
  Mint cur = T;
  for(int i = 1; i <= k; i++) cur *= T - i;
  for(int i = 0; i < m; i++) {
    ret[i] = cur * dh[k + i];
    cur *= T + i + 1;
    cur *= h[i];
  }
  return ret;
}

multiply は例えば、 ntt を畳み込み用の構造体として、以下を渡せばよいです。

auto multiply = [&](const vector< modint > &a, const vector< modint > &b) {
  return ntt.multiply(a, b);
};

信憑性

ここに示したコードは以下のいずれかでverify済みです。実装が上手いので多分あってますが間違ってたら燃やしてください。

まとめ

ねんね