#include "mersenne_twister_64.h"

#define MM 156
#define MATRIX_A 0xb5026f5aa96619e9ull
#define UM 0xffffffff80000000ull
#define LM 0x7fffffffull

/* Seeds the generator from an array of 64-bit words.
 *
 * The first min(size, KIT_MT64_N) words of `seed` are copied verbatim into
 * the state. Every remaining state word i is filled with the recurrence
 *   mt[i] = 6364136223846793005 * (mt[i-1] ^ (mt[i-1] >> 62)) + i
 * so with size == 1 this matches MT19937-64's single-value seeding.
 *
 * NOTE(review): for size > 1 this is NOT the reference init_by_array64()
 * mixing scheme; seed words are used as-is. Kept unchanged for output
 * compatibility.
 *
 * If size <= 0 or seed is NULL there is no word to start the recurrence from
 * (the previous code read mt[-1]). In that case the state is seeded with
 * 5489, the default seed of the reference mt19937-64.c.
 *
 * On return, state->index == KIT_MT64_N. This makes the first call to
 * kit_mt64_generate() regenerate the whole block before producing output.
 */
void kit_mt64_init_array(kit_mt64_state_t *state, i64 size,
                         u64 *seed) {
  i64 count = size < KIT_MT64_N ? size : KIT_MT64_N;
  i64 i;

  if (count <= 0 || !seed) {
    /* Empty/missing seed: start the recurrence from the default seed. */
    state->mt[0] = 5489ull;
    count        = 1;
  } else {
    for (i = 0; i < count; i++) state->mt[i] = seed[i];
  }

  /* count >= 1 here, so mt[i - 1] is always in bounds. */
  for (i = count; i < KIT_MT64_N; i++)
    state->mt[i] = 6364136223846793005ull *
                       (state->mt[i - 1] ^ (state->mt[i - 1] >> 62u)) +
                   (u64) i;

  state->index = KIT_MT64_N;
}

/* Seeds the generator from a single 64-bit value.
 *
 * Thin wrapper: forwards `seed` to kit_mt64_init_array() as a one-element
 * seed array.
 */
void kit_mt64_init(kit_mt64_state_t *state, u64 seed) {
  u64 key[1];

  key[0] = seed;
  kit_mt64_init_array(state, 1, key);
}

/* Regenerates all KIT_MT64_N state words in place and rewinds the read
 * position to 0.
 *
 * Each word k is rebuilt from:
 *   - the upper bits (UM) of word k,
 *   - the lower bits (LM) of its successor,
 *   - the word MM positions ahead,
 * with all indices wrapping around the circular state array. Words are
 * rewritten in ascending order. Any index that wraps past the end therefore
 * reads a word already rewritten earlier in this same pass.
 */
void kit_mt64_rotate(kit_mt64_state_t *state) {
  i32 k;

  for (k = 0; k < KIT_MT64_N; k++) {
    i32 next  = (k + 1 == KIT_MT64_N) ? 0 : k + 1;
    i32 ahead = (k + MM < KIT_MT64_N) ? k + MM : k + MM - KIT_MT64_N;

    u64 y       = (state->mt[k] & UM) | (state->mt[next] & LM);
    u64 twisted = y >> 1u;

    /* Odd y: fold in the twist matrix constant. */
    if (y & 1ull) {
      twisted ^= MATRIX_A;
    }

    state->mt[k] = state->mt[ahead] ^ twisted;
  }

  state->index = 0;
}

/* Returns the next 64-bit pseudo-random value.
 *
 * When every state word has been consumed (index >= KIT_MT64_N), the block
 * is first regenerated with kit_mt64_rotate(). The current word is then
 * read, the read position advanced, and the word passed through the
 * MT19937-64 tempering shifts/masks before being returned.
 */
u64 kit_mt64_generate(kit_mt64_state_t *state) {
  u64 y;

  if (!(state->index < KIT_MT64_N)) {
    kit_mt64_rotate(state);
  }

  y = state->mt[state->index];
  state->index++;

  /* Tempering. */
  y ^= (y >> 29u) & 0x5555555555555555ull;
  y ^= (y << 17u) & 0x71d67fffeda60000ull;
  y ^= (y << 37u) & 0xfff7eee000000000ull;

  return y ^ (y >> 43u);
}

#undef MM
#undef MATRIX_A
#undef UM
#undef LM