montgomery modint
乗算がはやーいmodint
Code
#include <iostream>
template<uint32_t M>
struct montgomery_modint {
using Self = montgomery_modint<M>;
using i32 = int32_t;
using u32 = uint32_t;
using u64 = uint64_t;
static constexpr u32 get_r() {
u32 res = M;
for(int i = 0; i < 4; i++) res *= 2 - M * res;
return res;
}
static constexpr u32 reduce(u64 a) {
return (a + u64(u32(a) * u32(-r)) * M) >> 32;
}
static constexpr u32 r = get_r();
static constexpr u32 n2 = -u64(M) % M;
u32 a;
constexpr montgomery_modint() : a(0) {}
constexpr montgomery_modint(int64_t a) : a(reduce(u64(a % M + M) * n2)) {}
constexpr u32 val() const {
u32 res = reduce(a);
return res >= M ? res - M : res;
}
constexpr Self pow(u64 r) const {
Self ans(1); Self aa = *this;
while(r) { if(r & 1) ans *= aa; aa *= aa; r >>= 1; }
return ans;
}
constexpr Self inv() const { return this->pow(M - 2); }
constexpr Self& operator+=(const Self& r) {
if(i32(a += r.a - 2 * M) < 0) a += 2 * M;
return *this;
}
constexpr Self& operator-=(const Self& r) {
if(i32(a -= r.a) < 0) a += 2 * M;
return *this;
}
constexpr Self& operator*=(const Self& r) {
a = reduce(u64(a) * r.a);
return *this;
}
constexpr Self& operator/=(const Self& r) {
*this *= r.inv();
return *this;
}
constexpr Self operator+(const Self r) const { return Self(*this) += r; }
constexpr Self operator-(const Self r) const { return Self(*this) -= r; }
constexpr Self operator-() const { return Self() - Self(*this); }
constexpr Self operator*(const Self r) const { return Self(*this) *= r; }
constexpr Self operator/(const Self r) const { return Self(*this) /= r; }
constexpr bool operator==(const Self& r) const {
return (a >= M ? a - M : a) == (r.a >= M ? r.a - M : r.a);
}
constexpr bool operator!=(const Self& r) const {
return (a >= M ? a - M : a) == (r.a >= M ? r.a - M : r.a);
}
};
template<uint32_t M>
std::ostream& operator<<(std::ostream& os, const montgomery_modint<M>& m) {
return os << m.val();
}
template<uint32_t M>
std::istream& operator>>(std::istream& is, montgomery_modint<M>& m) {
int64_t t;
is >> t;
m = montgomery_modint<M>(t);
return is;
}
template<uint32_t mod>
using modint = montgomery_modint<mod>;