aboutsummaryrefslogtreecommitdiff
path: root/poly_compose.h
blob: 88f339dc7d70c35af7dd8be205622129a57b7b31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#pragma once

#include "snippets/min_pow_of_two.h"

#include <algorithm>
#include <vector>

/// Polynomial composition: F(G(z)) mod z^n, where n = f.size().
///
/// Baby-step / giant-step split with s = ceil(sqrt(n)):
///   F(G) = sum_q (G^s)^q * C_q(G),  C_q(G) = sum_{k<s} f[q*s + k] * G^k.
/// Each C_q is accumulated from the baby-step powers G^0 .. G^{s-1}. The
/// outer sum is then evaluated with Horner's rule in G^s. Products are
/// formed in buffers of length n2 >= 2n - 1 via Poly::dif /
/// Poly::dot_product_and_dit (presumably a forward/inverse NTT pair, i.e. a
/// cyclic convolution of length n2 -- confirm against Poly). Each product is
/// truncated back to n coefficients.
///
/// Requirements on Poly, as used below: a nested `Mod` type, operator[],
/// size(), data(), an iterator-range constructor, and the static helpers
/// reserve, raw_buffer<I>, copy_and_fill0, dif and dot_product_and_dit.
template <typename Poly> struct PolyCompose {
  using Mod = typename Poly::Mod;

  /// Returns the first f.size() coefficients of F(G(z)).
  /// @param f  coefficients of F; its size n fixes the output length.
  /// @param g  coefficients of G; only g[0..n) affects the result, and any
  ///           higher coefficients are ignored.
  /// Side effect: overwrites the scratch buffers raw_buffer<0> and
  /// raw_buffer<1>.
  Poly operator()(const Poly &f, const Poly &g) {
    const int n = f.size();
    if (n == 0) {
      // An empty F composes to the empty polynomial. Returning early also
      // avoids calling min_pow_of_two with the negative length n + n - 1.
      return f;
    }
    const int n2 = min_pow_of_two(n + n - 1);
    Poly::reserve(n2);
    // sqrt_n = ceil(sqrt(n)): the smallest s with s * s >= n.
    int sqrt_n = 1;
    while (sqrt_n * sqrt_n < n) {
      sqrt_n++;
    }

    // Only g[0..n) is relevant mod z^n. Copy just that prefix: a
    // coefficient g[b] with b >= n would wrap around the length-n2 cyclic
    // product onto an index < n and corrupt the result.
    const int g_len = std::min<int>(n, g.size());
    const auto dif_g = Poly::template raw_buffer<0>();
    Poly::copy_and_fill0(n2, dif_g, g_len, g.data());
    Poly::dif(n2, dif_g);

    // coef[q * n + j] = [z^j] C_q(G).
    std::vector<Mod> coef(sqrt_n * n);
    // k == 0: G^0 = 1 only contributes to the constant term.
    for (int i = 0, offset = 0; i < n; i += sqrt_n, offset += n) {
      coef[offset] += f[i];
    }

    // Baby steps: pow_g holds G^k mod z^n in coefficient form, zero-padded
    // up to n2.
    const auto pow_g = Poly::template raw_buffer<1>();
    Poly::copy_and_fill0(n2, pow_g, g_len, g.data());
    for (int k = 1; k < sqrt_n; k++) {
      for (int i = k, offset = 0; i < n; i += sqrt_n, offset += n) {
        for (int j = 0; j < n; j++) {
          coef[offset + j] += f[i] * pow_g[j];
        }
      }
      // pow_g <- G^{k+1} mod z^n
      Poly::dif(n2, pow_g);
      Poly::dot_product_and_dit(n2, pow_g, pow_g, dif_g);
      std::fill(pow_g + n, pow_g + n2, Mod{0});
    }
    // Here pow_g = G^{sqrt_n} mod z^n (also when sqrt_n == 1, where the
    // loop above does not run and pow_g is G itself).
    const auto dif_pow_g = pow_g;
    Poly::dif(n2, dif_pow_g);

    // Giant steps: Horner's rule in G^{sqrt_n}, reusing dif_g's buffer.
    // Seed res with C_{s-1} directly instead of multiplying a zero buffer,
    // then res <- res * G^s + C_q for q = s-2, ..., 0.
    const auto res = dif_g;
    std::fill(res, res + n2, Mod{0});
    const auto top = coef.begin() + (sqrt_n - 1) * n;
    std::copy(top, top + n, res);
    for (int q = sqrt_n - 2; q >= 0; q--) {
      Poly::dif(n2, res);
      Poly::dot_product_and_dit(n2, res, res, dif_pow_g);
      std::fill(res + n, res + n2, Mod{0});
      const int offset = q * n;
      for (int j = 0; j < n; j++) {
        res[j] += coef[offset + j];
      }
    }
    return Poly{res, res + n};
  }
};