This documentation is automatically generated by online-judge-tools/verification-helper

:warning: spec/math/dynamic_mint_spec.cr

Depends on: `src/math/dynamic_mint.cr`

Code

require "spec"
require "../../src/math/dynamic_mint"

# Short name for the type under test.
private alias M = ::DynamicMint
# Modulus the `.mod` spec expects DynamicMint to start with.
private Init = 998_244_353

# Switches `M.mod` to `mod`, then for every pair of residues `lhs, rhs` in
# `0...mod` checks that `M.new(lhs) op M.new(rhs)` equals `(lhs op rhs) % mod`.
private macro check_binary_operator(mod, op)
  %modulus = {{mod}}
  M.mod = %modulus
  %residues = 0...%modulus
  %residues.each do |lhs|
    %residues.each do |rhs|
      expected = (lhs {{ op.id }} rhs) % %modulus
      (M.new(lhs) {{ op.id }} M.new(rhs)).should eq expected
    end
  end
end

# Switches `M.mod` to `mod`, then checks that calling `method` on `M.new(n)`
# gives the same result as calling it on the plain integer `n`, for every `n`
# in `0...mod`.
private macro check_method(mod, method)
  %modulus = {{mod}}
  M.mod = %modulus
  %modulus.times do |n|
    M.new(n).{{ method.id }}.should eq n.{{ method.id }}
  end
end

# Like `check_method`, but reduces the plain-integer result modulo `mod`
# before comparing (for methods such as `succ` that can leave `0...mod`).
private macro check_method_mod(mod, method)
  %modulus = {{mod}}
  M.mod = %modulus
  %modulus.times do |n|
    expected = n.{{ method.id }} % %modulus
    M.new(n).{{ method.id }}.should eq expected
  end
end

# Behavioral spec for DynamicMint. Examples mutate the shared modulus
# (`M.mod = ...`), so each one sets the modulus it needs up front.
describe "DynamicModint" do
  it ".mod and .setmod" do
    M.mod.should eq Init
    {42, Int32::MAX}.each do |m|
      M.mod = m
      M.mod.should eq m
    end
    {0, -1}.each do |bad|
      expect_raises(ArgumentError) { M.mod = bad }
    end
  end

  it ".zero" do
    M.mod = 60
    M.zero.should eq 0
  end

  it "#==(x)" do
    M.mod = 60
    zero, sixty = M.new(0), M.new(60)
    (zero == M.new(0)).should be_true
    (zero == sixty).should be_true
    (zero == 0).should be_true
    (zero == 1).should be_false
    (sixty == 0).should be_true
    (sixty == 60).should be_false
  end

  it "#+" do
    M.mod = Init
    {1 => 1, Init => 0}.each do |input, expected|
      (+M.new(input)).should eq expected
    end
  end

  it "#-" do
    M.mod = Init
    negated_one = -M.new(1)
    negated_one.should eq Init - 1
    (-M.new(0)).should eq 0

    M.mod = 3
    negated_one = -M.new(1)
    negated_one.should eq 2
  end

  {% for op in [:+, :-, :*] %}
    it "##{{{ op }}}" do
      {1, 60, 1009}.each do |mod|
        check_binary_operator(mod, {{ op }})
      end
    end
  {% end %}

  {% for op in [:/, ://] %}
    it "##{{{ op }}}" do
      {1, 60, 1009}.each do |mod|
        M.mod = mod
        divisors = (0...mod).select { |d| d.gcd(mod) == 1 }
        (0...mod).each do |x|
          divisors.each do |y|
            quotient = M.new(x) {{ op.id }} y
            (quotient * y).should eq x
          end
        end
      end
    end
  {% end %}

  it "#**(x)" do
    limit, stride = 10i64**18, 10i64**16
    {1, 60, 1009}.each do |mod|
      M.mod = mod
      mod.to_i64.times do |x|
        0i64.step(to: limit, by: stride) do |e|
          expected = AtCoder::Math.pow_mod(x, e, mod)
          (M.new(x) ** e).should eq expected
        end
      end
    end
  end

  it "#inv" do
    {1, 60, 1009, 1000003}.each do |mod|
      M.mod = mod
      (1...mod).each.select { |x| x.gcd(mod) == 1 }.each do |x|
        (M.new(x).inv * x).should eq 1
      end
    end
  end

  {% for method in [:succ, :pred, :abs, :to_i64] %}
    it "##{{{ method }}}" do
      {1, 60, 1009, 1000003}.each do |mod|
        check_method_mod(mod, {{ method }})
      end
    end
  {% end %}

  {% for method in [:to_s, :inspect] %}
    it "##{{{ method }}}" do
      {1, 60, 1009, 1000003}.each do |mod|
        check_method(mod, {{ method }})
      end
    end
  {% end %}

  it "compares" do
    {% for op in %w[< <= > >=] %}
      expect_raises(NotImplementedError) { M.new(0) {{ op.id }} 0 }
    {% end %}
  end
end
require "spec"

# require "../../src/math/dynamic_mint"
# require "./barrett"
# Barrett reduction: computes products modulo a fixed 32-bit modulus using a
# precomputed reciprocal instead of a hardware division.
struct Barrett
  # The modulus `m`.
  getter mod : UInt32
  # `ceil(2^64 / m)` truncated to 64 bits (it wraps to 0 when `m == 1`).
  getter inv : UInt64

  # Requires `1 <= mod < 2^31`.
  def initialize(@mod)
    @inv = (UInt64::MAX // @mod) &+ 1
  end

  # Calculates `a * b % mod`.
  #
  # Requires `0 <= a < mod` and `0 <= b < mod`.
  def mul(a : UInt32, b : UInt32) : UInt32
    product = a.to_u64! &* b
    # Estimated quotient `product / mod`, taken from the high 64 bits of product * inv.
    estimate = ((product.to_u128! &* @inv) >> 64).to_u64!
    residue = (product &- estimate &* @mod).to_u32!
    # If the estimate overshoots, the subtraction wraps to a value >= mod;
    # adding mod (with wrap-around) brings it back into 0...mod.
    @mod <= residue ? residue &+ @mod : residue
  end
end

# require "../../atcoder/src/Math"
# ac-library.cr by hakatashi https://github.com/google/ac-library.cr
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module AtCoder
  # Implements [ACL's Math library](https://atcoder.github.io/ac-library/master/document_en/math.html)
  module Math
    # Extended Euclidean algorithm.
    #
    # Returns `{g, x}` where `g = gcd(|a|, |b|)` and `a * x ≡ g (mod b)`.
    # Only the coefficient of `a` is returned, which is all `inv_mod` and
    # `crt` use.
    def self.extended_gcd(a, b)
      last_remainder, remainder = a.abs, b.abs
      x, last_x = 0_i64, 1_i64
      while remainder != 0
        new_last_remainder = remainder
        quotient, remainder = last_remainder.divmod(remainder)
        last_remainder = new_last_remainder
        x, last_x = last_x - quotient * x, x
      end

      # The loop ran on |a|; restore the sign so the coefficient applies to `a`.
      return last_remainder, last_x * (a < 0 ? -1 : 1)
    end

    # Implements atcoder::inv_mod(value, modulo): returns `x` with
    # `value * x ≡ 1 (mod modulo)`, in `0...modulo` for positive `modulo`.
    #
    # Raises `ArgumentError` when `value` and `modulo` are not coprime.
    def self.inv_mod(value, modulo)
      gcd, inv = extended_gcd(value, modulo)
      if gcd != 1
        raise ArgumentError.new("#{value} and #{modulo} are not coprime")
      end
      inv % modulo
    end

    # Simplified AtCoder::Math.pow_mod with support of Int64.
    #
    # Returns `base ** exponent % modulo`. A negative `exponent` exponentiates
    # `inv_mod(base, modulo)` instead.
    # NOTE(review): `base` is presumably expected in `0...modulo`; it is not
    # reduced before squaring — confirm with callers.
    def self.pow_mod(base, exponent, modulo)
      # Every integer is congruent to 0 modulo 1. ACL's pow_mod returns 0 for
      # m == 1; without this check an exponent of 0 would yield 1 here.
      return base.class.zero if modulo == 1
      if exponent == 0
        return base.class.zero + 1
      end
      if base == 0
        return base
      end
      b = exponent > 0 ? base : inv_mod(base, modulo)
      e = exponent.abs
      ret = 1_i64
      while e > 0
        if e % 2 == 1
          ret = mul_mod(ret, b, modulo)
        end
        b = mul_mod(b, b, modulo)
        e //= 2
      end
      ret
    end

    # Calculates a * b % mod without overflow detection.
    #
    # Moduli below Int32::MAX multiply directly; larger ones assemble the
    # 128-bit product from 32-bit halves and reduce it as Int128.
    # NOTE(review): assumes non-negative `a` and `b` (see the bit-width notes).
    @[AlwaysInline]
    def self.mul_mod(a : Int64, b : Int64, mod : Int64)
      if mod < Int32::MAX
        return a * b % mod
      end

      # 31-bit width
      a_high = (a >> 32).to_u64
      # 32-bit width
      a_low = (a & 0xFFFFFFFF).to_u64
      # 31-bit width
      b_high = (b >> 32).to_u64
      # 32-bit width
      b_low = (b & 0xFFFFFFFF).to_u64

      # 31-bit + 32-bit + 1-bit = 64-bit
      c = a_high * b_low + b_high * a_low
      c_high = c >> 32
      c_low = c & 0xFFFFFFFF

      # 31-bit + 31-bit
      res_high = a_high * b_high + c_high
      # 32-bit + 32-bit
      res_low = a_low * b_low
      res_low_high = res_low >> 32
      res_low_low = res_low & 0xFFFFFFFF

      # Overflow
      if res_low_high + c_low >= 0x100000000
        res_high += 1
      end

      res_low = (((res_low_high + c_low) & 0xFFFFFFFF) << 32) | res_low_low

      (((res_high.to_i128 << 64) | res_low) % mod).to_i64
    end

    # Fallback for other argument types: multiplies in Int64 and converts the
    # reduced result back to the type of `mod`.
    @[AlwaysInline]
    def self.mul_mod(a, b, mod)
      typeof(mod).new(a.to_i64 * b % mod)
    end

    # Implements atcoder::crt(remainders, modulos).
    #
    # Returns `{answer, total_modulo}`. Here `answer ≡ remainders[i] (mod modulos[i])`
    # for every `i`, and `total_modulo` is the lcm of the moduli. Returns `{0, 0}`
    # when the congruences are inconsistent.
    # Raises `ArgumentError` when the two collections differ in size.
    def self.crt(remainders, modulos)
      raise ArgumentError.new unless remainders.size == modulos.size

      total_modulo = 1_i64
      answer = 0_i64

      remainders.zip(modulos).each do |(remainder, modulo)|
        gcd, p = extended_gcd(total_modulo, modulo)
        if (remainder - answer) % gcd != 0
          return 0_i64, 0_i64
        end
        tmp = (remainder - answer) // gcd * p % (modulo // gcd)
        answer += total_modulo * tmp
        total_modulo *= modulo // gcd
      end

      return answer % total_modulo, total_modulo
    end

    # Implements atcoder::floor_sum(n, m, a, b): the sum over `i` in `0...n`
    # of `(a * i + b) // m`, computed in Int64.
    #
    # Negative `a` / `b` are first shifted into `0...m`, and their contribution
    # is subtracted from the result.
    def self.floor_sum(n, m, a, b)
      n, m, a, b = n.to_i64, m.to_i64, a.to_i64, b.to_i64
      res = 0_i64

      if a < 0
        a2 = a % m
        res -= n * (n - 1) // 2 * ((a2 - a) // m)
        a = a2
      end

      if b < 0
        b2 = b % m
        res -= n * ((b2 - b) // m)
        b = b2
      end

      res + floor_sum_unsigned(n, m, a, b)
    end

    # Same sum as `floor_sum`, for `a >= 0` and `b >= 0`. It repeatedly removes
    # the whole multiples of `m` from `a` and `b`, then swaps the roles of `a`
    # and `m`, stopping once `a * n + b < m`.
    private def self.floor_sum_unsigned(n, m, a, b)
      res = 0_i64

      loop do
        if a >= m
          res += n * (n - 1) // 2 * (a // m)
          a = a % m
        end

        if b >= m
          res += n * (b // m)
          b = b % m
        end

        y_max = a * n + b
        break if y_max < m

        n = y_max // m
        b = y_max % m
        m, a = a, m
      end

      res
    end
  end
end

# Modular integer whose modulus is chosen at runtime (`DynamicMint.mod = m`)
# and is shared by every instance through the class variable `@@bt`.
#
# Constructors other than `.raw` reduce their argument into `0...mod`; `.raw`
# stores its argument unchanged and is meant for already-reduced values.
struct DynamicMint
  # Barrett reducer for the current modulus (998244353 until `.mod=` is called).
  @@bt : Barrett = Barrett.new(998244353u32)
  alias Mint = DynamicMint

  # Returns the current modulus.
  def self.mod : Int32
    @@bt.mod.to_i
  end

  # Replaces the modulus used by all instances.
  #
  # Raises `ArgumentError` unless `m >= 1`. Values already stored in existing
  # instances are not re-reduced under the new modulus.
  def self.mod=(m : Int32)
    raise ArgumentError.new unless 1 <= m
    @@bt = Barrett.new(m.to_u32)
  end

  # Wraps `v` without reducing it; the caller must ensure `0 <= v < mod`.
  def self.raw(v : Int32)
    result = Mint.new
    result.value = v
    result
  end

  # Like `.raw(Int32)`, converting `v` with the wrapping `to_i!` first.
  def self.raw(v)
    result = Mint.new
    result.value = v.to_i!
    result
  end

  # Returns the residue 0.
  def self.zero
    Mint.new
  end

  # The stored residue.
  getter value : Int32

  protected def value=(v : Int32)
    @value = v
  end

  def initialize
    @value = 0
  end

  def initialize(x : self)
    @value = x.value
  end

  # Residue of `x`; negative values wrap to a non-negative result.
  def initialize(x : Int32)
    @value = x % Mint.mod
  end

  # Residue of any other integer `x`. The value is widened to Int64 before
  # reducing, so Int64 inputs such as `10i64**18` do not overflow a 32-bit
  # conversion. Values outside the Int64 range raise from `to_i64`.
  def initialize(x : Int)
    @value = (x.to_i64 % Mint.mod).to_i
  end

  def ==(v : self)
    value == v.value
  end

  # Compares the stored residue with `v` as-is (`v` is not reduced).
  def ==(v)
    value == v
  end

  def +
    self
  end

  def -
    Mint.raw value == 0 ? 0 : Mint.mod &- value
  end

  def +(v)
    x = value &+ Mint.new(v).value
    x &-= Mint.mod if x >= Mint.mod
    Mint.raw x
  end

  def -(v)
    x = value &- Mint.new(v).value
    x &+= Mint.mod if x < 0
    Mint.raw x
  end

  def *(v)
    Mint.raw @@bt.mul(value.to_u!, Mint.new(v).value.to_u!)
  end

  # Multiplies by the modular inverse of `v`; raises `ArgumentError` (from
  # `inv_mod`) when `v` is not coprime to the modulus.
  def /(v)
    self * Mint.new(v).inv
  end

  # Same as `/`.
  def //(v)
    self / v
  end

  # Returns `self ** exponent` by binary exponentiation.
  #
  # A negative `exponent` exponentiates `inv`, so it raises `ArgumentError`
  # when `self` is not invertible. The accumulator starts at `Mint.new(1)`,
  # which is reduced, so every power is 0 under modulus 1.
  def **(exponent : Int)
    base = exponent < 0 ? inv : self
    result = Mint.new(1)
    e = exponent
    until e == 0
      result *= base if e.odd?
      base *= base
      # tdiv truncates toward zero, so negative exponents shrink in magnitude
      # without negating them (which would overflow for Int::MIN).
      e = e.tdiv(2)
    end
    result
  end

  # Modular inverse; raises `ArgumentError` when `value` and the modulus are
  # not coprime (e.g. a value of 0 with a modulus greater than 1).
  def inv
    Mint.raw AtCoder::Math.inv_mod(value, Mint.mod)
  end

  # Next residue, wrapping `mod - 1` to 0.
  def succ
    Mint.raw value == Mint.mod &- 1 ? 0 : value &+ 1
  end

  # Previous residue, wrapping 0 to `mod - 1`.
  def pred
    Mint.raw value == 0 ? Mint.mod &- 1 : value &- 1
  end

  # Residues are already non-negative, so this is `self`.
  def abs
    self
  end

  def abs2
    self * self
  end

  def to_i64 : Int64
    value.to_i64
  end

  def to_s(io : IO) : Nil
    value.to_s(io)
  end

  def inspect(io : IO) : Nil
    value.inspect(io)
  end

  # Ordering is deliberately unsupported: each comparison raises NotImplementedError.
  {% for op in %w[< <= > >=] %}
    def {{op.id}}(other)
      raise NotImplementedError.new({{op}})
    end
  {% end %}
end

# Short name for the type under test.
private alias M = ::DynamicMint
# Modulus the `.mod` spec expects DynamicMint to start with.
private Init = 998_244_353

# Switches `M.mod` to `mod`, then for every pair of residues `lhs, rhs` in
# `0...mod` checks that `M.new(lhs) op M.new(rhs)` equals `(lhs op rhs) % mod`.
private macro check_binary_operator(mod, op)
  %modulus = {{mod}}
  M.mod = %modulus
  %residues = 0...%modulus
  %residues.each do |lhs|
    %residues.each do |rhs|
      expected = (lhs {{ op.id }} rhs) % %modulus
      (M.new(lhs) {{ op.id }} M.new(rhs)).should eq expected
    end
  end
end

# Switches `M.mod` to `mod`, then checks that calling `method` on `M.new(n)`
# gives the same result as calling it on the plain integer `n`, for every `n`
# in `0...mod`.
private macro check_method(mod, method)
  %modulus = {{mod}}
  M.mod = %modulus
  %modulus.times do |n|
    M.new(n).{{ method.id }}.should eq n.{{ method.id }}
  end
end

# Like `check_method`, but reduces the plain-integer result modulo `mod`
# before comparing (for methods such as `succ` that can leave `0...mod`).
private macro check_method_mod(mod, method)
  %modulus = {{mod}}
  M.mod = %modulus
  %modulus.times do |n|
    expected = n.{{ method.id }} % %modulus
    M.new(n).{{ method.id }}.should eq expected
  end
end

# Behavioral spec for DynamicMint. Examples mutate the shared modulus
# (`M.mod = ...`), so each one sets the modulus it needs up front.
describe "DynamicModint" do
  it ".mod and .setmod" do
    M.mod.should eq Init
    {42, Int32::MAX}.each do |m|
      M.mod = m
      M.mod.should eq m
    end
    {0, -1}.each do |bad|
      expect_raises(ArgumentError) { M.mod = bad }
    end
  end

  it ".zero" do
    M.mod = 60
    M.zero.should eq 0
  end

  it "#==(x)" do
    M.mod = 60
    zero, sixty = M.new(0), M.new(60)
    (zero == M.new(0)).should be_true
    (zero == sixty).should be_true
    (zero == 0).should be_true
    (zero == 1).should be_false
    (sixty == 0).should be_true
    (sixty == 60).should be_false
  end

  it "#+" do
    M.mod = Init
    {1 => 1, Init => 0}.each do |input, expected|
      (+M.new(input)).should eq expected
    end
  end

  it "#-" do
    M.mod = Init
    negated_one = -M.new(1)
    negated_one.should eq Init - 1
    (-M.new(0)).should eq 0

    M.mod = 3
    negated_one = -M.new(1)
    negated_one.should eq 2
  end

  {% for op in [:+, :-, :*] %}
    it "##{{{ op }}}" do
      {1, 60, 1009}.each do |mod|
        check_binary_operator(mod, {{ op }})
      end
    end
  {% end %}

  {% for op in [:/, ://] %}
    it "##{{{ op }}}" do
      {1, 60, 1009}.each do |mod|
        M.mod = mod
        divisors = (0...mod).select { |d| d.gcd(mod) == 1 }
        (0...mod).each do |x|
          divisors.each do |y|
            quotient = M.new(x) {{ op.id }} y
            (quotient * y).should eq x
          end
        end
      end
    end
  {% end %}

  it "#**(x)" do
    limit, stride = 10i64**18, 10i64**16
    {1, 60, 1009}.each do |mod|
      M.mod = mod
      mod.to_i64.times do |x|
        0i64.step(to: limit, by: stride) do |e|
          expected = AtCoder::Math.pow_mod(x, e, mod)
          (M.new(x) ** e).should eq expected
        end
      end
    end
  end

  it "#inv" do
    {1, 60, 1009, 1000003}.each do |mod|
      M.mod = mod
      (1...mod).each.select { |x| x.gcd(mod) == 1 }.each do |x|
        (M.new(x).inv * x).should eq 1
      end
    end
  end

  {% for method in [:succ, :pred, :abs, :to_i64] %}
    it "##{{{ method }}}" do
      {1, 60, 1009, 1000003}.each do |mod|
        check_method_mod(mod, {{ method }})
      end
    end
  {% end %}

  {% for method in [:to_s, :inspect] %}
    it "##{{{ method }}}" do
      {1, 60, 1009, 1000003}.each do |mod|
        check_method(mod, {{ method }})
      end
    end
  {% end %}

  it "compares" do
    {% for op in %w[< <= > >=] %}
      expect_raises(NotImplementedError) { M.new(0) {{ op.id }} 0 }
    {% end %}
  end
end
Back to top page