Numeric Theoretic Transform
Spec
素数\( m = 2^k + 1, 2^k >= n\)として\(F_m \)剰余環上での離散フーリエ変換(discrete Fourier transformation)を\( O(n \log n) \)で行う
NTT_PRIMES
はるまライブラリからお借りしています
ほとんど単体でNTTを使うことはなさそう(FPSで使うね)
Code
#include "../modint.hpp"
#include <vector>
using namespace std;
constexpr i64 NTT_PRIMES[][2] = {
{1224736769, 3}, // 2^24 * 73 + 1,
{1053818881, 7}, // 2^20 * 3 * 5 * 67 + 1
{1051721729, 6}, // 2^20 * 17 * 59 + 1
{1045430273, 3}, // 2^20 * 997 + 1
{1012924417, 5}, // 2^21 * 3 * 7 * 23 + 1
{1007681537, 3}, // 2^20 * 31^2 + 1
{1004535809, 3}, // 2^21 * 479 + 1
{998244353, 3}, // 2^23 * 7 * 17 + 1
{985661441, 3}, // 2^22 * 5 * 47 + 1
{976224257, 3}, // 2^20 * 7^2 * 19 + 1
{975175681, 17}, // 2^21 * 3 * 5 * 31 + 1
{962592769, 7}, // 2^21 * 3^3 * 17 + 1
{950009857, 7}, // 2^21 * 4 * 151 + 1
{943718401, 7}, // 2^22 * 3^2 * 5^2 + 1
{935329793, 3}, // 2^22 * 223 + 1
{924844033, 5}, // 2^21 * 3^2 * 7^2 + 1
{469762049, 3}, // 2^26 * 7 + 1
{167772161, 3}, // 2^25 * 5 + 1
};
template<const i64 mod, const i64 primitive>
vector<modint<mod>> number_theoretic_transform(vector<modint<mod>> a) {
i64 n = a.size();
for(i64 s = n >> 1; s >= 1; s >>= 1) {
modint<mod> zeta = modint<mod>(primitive).pow((mod - 1) / (s << 1));
for(i64 i = 0; i < n; i += (s << 1)) {
modint<mod> zi = 1;
for(i64 j = 0;j < s;j++) {
modint<mod> t = a[i + j] - a[s + i + j];
a[i + j] += a[s + i + j];
a[s + i + j] = t * zi;
zi = zi * zeta;
}
}
}
return a;
}
template<const i64 mod, const i64 primitive>
vector<modint<mod>> inverse_number_theoretic_transform(vector<modint<mod>> a) {
i64 n = a.size();
for(i64 s = 1; s < n; s <<= 1) {
modint<mod> zeta = modint<mod>(primitive).pow((mod - 1) / (s << 1)).pow(mod - 2);
for(i64 i = 0; i < n; i += (s << 1)) {
modint<mod> zi = 1;
for(i64 j = 0;j < s;j++) {
modint<mod> t = a[s + i + j] * zi;
a[s + i + j] = a[i + j] - t;
a[i + j] = a[i + j] + t;
zi = zi * zeta;
}
}
}
auto inv_n = modint<mod>(n).pow(mod - 2);
for(int i = 0;i < n;i++) a[i] *= inv_n;
return a;
}
4基底NTT
#include "../modint.hpp"
#include <vector>
using namespace std;
using i64 = long long;
constexpr i64 NTT_PRIMES[][2] = {
{1224736769, 3}, // 2^24 * 73 + 1,
{1053818881, 7}, // 2^20 * 3 * 5 * 67 + 1
{1051721729, 6}, // 2^20 * 17 * 59 + 1
{1045430273, 3}, // 2^20 * 997 + 1
{1012924417, 5}, // 2^21 * 3 * 7 * 23 + 1
{1007681537, 3}, // 2^20 * 31^2 + 1
{1004535809, 3}, // 2^21 * 479 + 1
{998244353, 3}, // 2^23 * 7 * 17 + 1
{985661441, 3}, // 2^22 * 5 * 47 + 1
{976224257, 3}, // 2^20 * 7^2 * 19 + 1
{975175681, 17}, // 2^21 * 3 * 5 * 31 + 1
{962592769, 7}, // 2^21 * 3^3 * 17 + 1
{950009857, 7}, // 2^21 * 4 * 151 + 1
{943718401, 7}, // 2^22 * 3^2 * 5^2 + 1
{935329793, 3}, // 2^22 * 223 + 1
{924844033, 5}, // 2^21 * 3^2 * 7^2 + 1
{469762049, 3}, // 2^26 * 7 + 1
{167772161, 3}, // 2^25 * 5 + 1
};
template<const i64 mod, const i64 primitive>
vector<modint<mod>> number_theoretic_transform4(vector<modint<mod>> a) {
i64 n = a.size();
vector<modint<mod>> b(a.size());
auto unit_i = modint<mod>(primitive).pow((mod - 1) / 4);
for(i64 s = 1, m = n; s < n; s <<= 1, m >>= 1) {
if(m == 2) {
for(i64 j = 0;j < s;j++) {
auto x = a[j + 0];
auto y = a[j + s];
b[j + 0] = x + y;
b[j + s] = x - y;
}
}
else {
modint<mod> zi1 = 1;
modint<mod> zi2 = 1;
modint<mod> zi3 = 1;
i64 m1 = m >> 2;
i64 m2 = m >> 1;
i64 m3 = m1 | m2;
modint<mod> zeta = modint<mod>(primitive).pow((mod - 1) / m);
for(i64 i = 0;i < m1;i++) {
for(i64 j = 0;j < s;j++) {
auto w = a[j + s * (i + 0)];
auto x = a[j + s * (i + m1)];
auto y = a[j + s * (i + m2)];
auto z = a[j + s * (i + m3)];
auto wy1 = w + y;
auto wy2 = w - y;
auto xz1 = x + z;
auto xz2 = (x - z) * unit_i;
b[j + s * (4 * i + 0)] = wy1 + xz1;
b[j + s * (4 * i + 1)] = (wy2 + xz2) * zi1;
b[j + s * (4 * i + 2)] = (wy1 - xz1) * zi2;
b[j + s * (4 * i + 3)] = (wy2 - xz2) * zi3;
}
zi1 = zi1 * zeta;
zi2 = zi1 * zi1;
zi3 = zi1 * zi2;
}
s <<= 1;
m >>= 1;
}
swap(a, b);
}
return a;
}
template<const i64 mod, const i64 primitive>
vector<modint<mod>> inverse_number_theoretic_transform4(vector<modint<mod>> a) {
i64 n = a.size();
vector<modint<mod>> b(a.size());
auto unit_i = modint<mod>(primitive).pow((mod - 1) / 4).inv();
i64 s = n;
i64 m = 1;
if(__builtin_ctzll(n) & 1) {
s >>= 1;
m <<= 1;
for(i64 j = 0;j < s;j++) {
auto x = a[j + 0];
auto y = a[j + s];
b[j + 0] = x + y;
b[j + s] = x - y;
}
swap(a, b);
}
for(; s >>= 2, m <<= 2, s >= 1;) {
{
modint<mod> zi1 = 1;
modint<mod> zi2 = 1;
modint<mod> zi3 = 1;
i64 m1 = m >> 2;
i64 m2 = m >> 1;
i64 m3 = m1 | m2;
modint<mod> zeta = modint<mod>(primitive).pow((mod - 1) / m).inv();
for(i64 i = 0;i < m1;i++) {
for(i64 j = 0;j < s;j++) {
auto w = a[j + s * (4 * i + 0)];
auto x = a[j + s * (4 * i + 1)] * zi1;
auto y = a[j + s * (4 * i + 2)] * zi2;
auto z = a[j + s * (4 * i + 3)] * zi3;
auto wy1 = w + y;
auto wy2 = x + z;
auto xz1 = w - y;
auto xz2 = (x - z) * unit_i;
b[j + s * (i + 0)] = wy1 + wy2;
b[j + s * (i + m1)] = xz1 + xz2;
b[j + s * (i + m2)] = wy1 - wy2;
b[j + s * (i + m3)] = xz1 - xz2;
}
zi1 = zi1 * zeta;
zi2 = zi1 * zi1;
zi3 = zi1 * zi2;
}
}
swap(a, b);
}
auto inv_n = modint<mod>(n).pow(mod - 2);
for(int i = 0;i < n;i++) a[i] *= inv_n;
return a;
}