From 0223e2faaebcb4065d9c37f55e8798196f911ae9 Mon Sep 17 00:00:00 2001
From: Mitya Selivanov <automainint@guattari.tech>
Date: Sun, 11 Dec 2022 23:46:56 +0100
Subject: [bigint] Big endian correctness

---
 source/kit/bigint.h                 | 225 ++++++++++++++++++++----------------
 source/test/unittests/bigint.test.c |   2 +-
 2 files changed, 126 insertions(+), 101 deletions(-)

(limited to 'source')

diff --git a/source/kit/bigint.h b/source/kit/bigint.h
index 30233c6..7bb6a24 100644
--- a/source/kit/bigint.h
+++ b/source/kit/bigint.h
@@ -13,9 +13,6 @@ extern "C" {
 #  define KIT_BIGINT_SIZE 64
 #endif
 
-#define KIT_UWORD_MAX ((uint_fast32_t) -1)
-#define KIT_UWORD_SIZE sizeof(uint_fast32_t)
-
 static_assert(sizeof(uint8_t) == 1, "uint8_t size should be 1 byte");
 static_assert(sizeof(uint32_t) == 4,
               "uint32_t size should be 4 bytes");
@@ -25,24 +22,10 @@ static_assert(KIT_BIGINT_SIZE > 0 && (KIT_BIGINT_SIZE % 8) == 0,
               "Invalid big integer size");
 
 typedef struct {
-  union {
-    struct {
-      uint8_t v8[KIT_BIGINT_SIZE];
-    };
-    struct {
-      uint32_t v32[KIT_BIGINT_SIZE / 4];
-    };
-    struct {
-      uint64_t v64[KIT_BIGINT_SIZE / 8];
-    };
-    struct {
-      uint_fast32_t v[KIT_BIGINT_SIZE / KIT_UWORD_SIZE];
-    };
-  };
+  uint32_t v[KIT_BIGINT_SIZE / 4];
 } kit_bigint_t;
 
-typedef uint_fast32_t kit_uword_t;
-typedef uint_fast8_t  kit_bit_t;
+typedef uint_fast8_t kit_bit_t;
 
 #ifdef __GNUC__
 #  pragma GCC diagnostic push
@@ -52,29 +35,29 @@ typedef uint_fast8_t  kit_bit_t;
 #  pragma GCC            optimize("O3")
 #endif
 
-static kit_bigint_t kit_bi_uword(kit_uword_t const x) {
+static kit_bigint_t kit_bi_uint32(uint32_t const x) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z);
   z.v[0] = x;
   return z;
 }
 
 static kit_bigint_t kit_bi_uint64(uint64_t const x) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
-  z.v64[0] = x;
+  memset(&z, 0, sizeof z);
+  z.v[0] = (uint32_t) (x & 0xffffffff);
+  z.v[1] = (uint32_t) (x >> 32);
   return z;
 }
 
 static int kit_bi_equal(kit_bigint_t const x, kit_bigint_t const y) {
-  return kit_ar_equal_bytes(1, KIT_BIGINT_SIZE, x.v8, 1,
-                            KIT_BIGINT_SIZE, y.v8);
+  return kit_ar_equal_bytes(1, KIT_BIGINT_SIZE, x.v, 1,
+                            KIT_BIGINT_SIZE, y.v);
 }
 
 static int kit_bi_compare(kit_bigint_t const x,
                           kit_bigint_t const y) {
-  for (ptrdiff_t i = KIT_BIGINT_SIZE / KIT_UWORD_SIZE - 1; i >= 0;
-       i--)
+  for (ptrdiff_t i = KIT_BIGINT_SIZE / 4 - 1; i >= 0; i--)
     if (x.v[i] < y.v[i])
       return -1;
     else if (x.v[i] > y.v[i])
@@ -83,32 +66,56 @@ static int kit_bi_compare(kit_bigint_t const x,
 }
 
 static ptrdiff_t kit_bi_significant_bit_count(kit_bigint_t const x) {
-  ptrdiff_t bytes = KIT_BIGINT_SIZE - 1;
+  ptrdiff_t n = KIT_BIGINT_SIZE / 4 - 1;
 
-  while (bytes > 0 && x.v8[bytes] == 0) bytes--;
+  while (n > 0 && x.v[n] == 0) n--;
 
-  uint8_t const byte = x.v8[bytes];
+  uint32_t const i32 = x.v[n];
 
-  if (byte == 0)
+  if (i32 == 0)
     return 0;
 
-  ptrdiff_t const bits = (byte & 0x80) != 0   ? 8
-                         : (byte & 0x40) != 0 ? 7
-                         : (byte & 0x20) != 0 ? 6
-                         : (byte & 0x10) != 0 ? 5
-                         : (byte & 0x08) != 0 ? 4
-                         : (byte & 0x04) != 0 ? 3
-                         : (byte & 0x02) != 0 ? 2
-                                              : 1;
-
-  return bytes * 8 + bits;
+  ptrdiff_t const bits = (i32 & 0x80000000u) != 0   ? 32
+                         : (i32 & 0x40000000u) != 0 ? 31
+                         : (i32 & 0x20000000u) != 0 ? 30
+                         : (i32 & 0x10000000u) != 0 ? 29
+                         : (i32 & 0x8000000u) != 0  ? 28
+                         : (i32 & 0x4000000u) != 0  ? 27
+                         : (i32 & 0x2000000u) != 0  ? 26
+                         : (i32 & 0x1000000u) != 0  ? 25
+                         : (i32 & 0x800000u) != 0   ? 24
+                         : (i32 & 0x400000u) != 0   ? 23
+                         : (i32 & 0x200000u) != 0   ? 22
+                         : (i32 & 0x100000u) != 0   ? 21
+                         : (i32 & 0x80000u) != 0    ? 20
+                         : (i32 & 0x40000u) != 0    ? 19
+                         : (i32 & 0x20000u) != 0    ? 18
+                         : (i32 & 0x10000u) != 0    ? 17
+                         : (i32 & 0x8000u) != 0     ? 16
+                         : (i32 & 0x4000u) != 0     ? 15
+                         : (i32 & 0x2000u) != 0     ? 14
+                         : (i32 & 0x1000u) != 0     ? 13
+                         : (i32 & 0x800u) != 0      ? 12
+                         : (i32 & 0x400u) != 0      ? 11
+                         : (i32 & 0x200u) != 0      ? 10
+                         : (i32 & 0x100u) != 0      ? 9
+                         : (i32 & 0x80u) != 0       ? 8
+                         : (i32 & 0x40u) != 0       ? 7
+                         : (i32 & 0x20u) != 0       ? 6
+                         : (i32 & 0x10u) != 0       ? 5
+                         : (i32 & 0x08u) != 0       ? 4
+                         : (i32 & 0x04u) != 0       ? 3
+                         : (i32 & 0x02u) != 0       ? 2
+                                                    : 1;
+
+  return n * 32 + bits;
 }
 
 static kit_bigint_t kit_bi_and(kit_bigint_t const x,
                                kit_bigint_t const y) {
   kit_bigint_t z;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++)
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++)
     z.v[i] = x.v[i] & y.v[i];
 
   return z;
@@ -118,7 +125,7 @@ static kit_bigint_t kit_bi_or(kit_bigint_t const x,
                               kit_bigint_t const y) {
   kit_bigint_t z;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++)
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++)
     z.v[i] = x.v[i] | y.v[i];
 
   return z;
@@ -128,54 +135,50 @@ static kit_bigint_t kit_bi_xor(kit_bigint_t const x,
                                kit_bigint_t const y) {
   kit_bigint_t z;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++)
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++)
     z.v[i] = x.v[i] ^ y.v[i];
 
   return z;
 }
 
-static kit_bigint_t kit_bi_shl_uword(kit_bigint_t const x,
-                                     kit_uword_t        y) {
+static kit_bigint_t kit_bi_shl_uint(kit_bigint_t const x,
+                                    uint32_t const     y) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z);
 
-  ptrdiff_t const words = (ptrdiff_t) (y / (8 * KIT_UWORD_SIZE));
-  ptrdiff_t const bits  = (ptrdiff_t) (y % (8 * KIT_UWORD_SIZE));
+  ptrdiff_t const words = (ptrdiff_t) (y / 32);
+  ptrdiff_t const bits  = (ptrdiff_t) (y % 32);
 
-  for (ptrdiff_t i = words; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE;
-       i++) {
+  for (ptrdiff_t i = words; i < KIT_BIGINT_SIZE / 4; i++) {
     z.v[i] |= x.v[i - words] << bits;
-    if (bits != 0 && i + 1 < KIT_BIGINT_SIZE / KIT_UWORD_SIZE)
-      z.v[i + 1] = x.v[i - words] >> (8 * KIT_UWORD_SIZE - bits);
+    if (bits != 0 && i + 1 < KIT_BIGINT_SIZE / 4)
+      z.v[i + 1] = x.v[i - words] >> (32 - bits);
   }
 
   return z;
 }
 
-static kit_bigint_t kit_bi_shr_uword(kit_bigint_t const x,
-                                     kit_uword_t        y) {
+static kit_bigint_t kit_bi_shr_uint(kit_bigint_t const x,
+                                    uint32_t const     y) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z);
 
-  ptrdiff_t const words = (ptrdiff_t) (y / (8 * KIT_UWORD_SIZE));
-  ptrdiff_t const bits  = (ptrdiff_t) (y % (8 * KIT_UWORD_SIZE));
+  ptrdiff_t const words = (ptrdiff_t) (y / 32);
+  ptrdiff_t const bits  = (ptrdiff_t) (y % 32);
 
-  for (ptrdiff_t i = KIT_BIGINT_SIZE / KIT_UWORD_SIZE - words - 1;
-       i >= 0; i--) {
+  for (ptrdiff_t i = KIT_BIGINT_SIZE / 4 - words - 1; i >= 0; i--) {
     z.v[i] |= x.v[i + words] >> bits;
     if (bits != 0 && i > 0)
-      z.v[i - 1] = x.v[i + words] << (8 * KIT_UWORD_SIZE - bits);
+      z.v[i - 1] = x.v[i + words] << (32 - bits);
   }
 
   return z;
 }
 
-static kit_bit_t kit_bi_carry(kit_uword_t const x,
-                              kit_uword_t const y,
-                              kit_bit_t const   carry) {
+static kit_bit_t kit_bi_carry(uint32_t const x, uint32_t const y,
+                              kit_bit_t const carry) {
   assert(carry == 0 || carry == 1);
-  return KIT_UWORD_MAX - x < y || KIT_UWORD_MAX - x - y < carry ? 1
-                                                                : 0;
+  return 0xffffffffu - x < y || 0xffffffffu - x - y < carry ? 1 : 0;
 }
 
 static kit_bigint_t kit_bi_add(kit_bigint_t const x,
@@ -183,7 +186,7 @@ static kit_bigint_t kit_bi_add(kit_bigint_t const x,
   kit_bigint_t z;
   kit_bit_t    carry = 0;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++) {
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) {
     z.v[i] = x.v[i] + y.v[i] + carry;
     carry  = kit_bi_carry(x.v[i], y.v[i], carry);
   }
@@ -195,9 +198,9 @@ static kit_bigint_t kit_bi_neg(kit_bigint_t const x) {
   kit_bigint_t y;
   kit_bit_t    carry = 1;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++) {
-    y.v[i] = (x.v[i] ^ KIT_UWORD_MAX) + carry;
-    carry  = kit_bi_carry(x.v[i] ^ KIT_UWORD_MAX, 0, carry);
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) {
+    y.v[i] = (x.v[i] ^ 0xffffffff) + carry;
+    carry  = kit_bi_carry(x.v[i] ^ 0xffffffff, 0, carry);
   }
 
   return y;
@@ -208,9 +211,9 @@ static kit_bigint_t kit_bi_sub(kit_bigint_t const x,
   kit_bigint_t z;
   kit_bit_t    carry = 1;
 
-  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / KIT_UWORD_SIZE; i++) {
-    z.v[i] = x.v[i] + (y.v[i] ^ KIT_UWORD_MAX) + carry;
-    carry  = kit_bi_carry(x.v[i], (y.v[i] ^ KIT_UWORD_MAX), carry);
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) {
+    z.v[i] = x.v[i] + (y.v[i] ^ 0xffffffff) + carry;
+    carry  = kit_bi_carry(x.v[i], (y.v[i] ^ 0xffffffff), carry);
   }
 
   return z;
@@ -219,19 +222,19 @@ static kit_bigint_t kit_bi_sub(kit_bigint_t const x,
 static kit_bigint_t kit_bi_mul_uint32(kit_bigint_t const x,
                                       uint32_t const     y) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z);
 
   if (y != 0)
     for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) {
-      if (x.v32[i] == 0)
+      if (x.v[i] == 0)
         continue;
 
-      uint64_t carry = ((uint64_t) x.v32[i]) * ((uint64_t) y);
+      uint64_t carry = ((uint64_t) x.v[i]) * ((uint64_t) y);
 
       for (ptrdiff_t k = i; k < KIT_BIGINT_SIZE / 4 && carry != 0;
            k++) {
-        uint64_t const sum = ((uint64_t) z.v32[k]) + carry;
-        z.v32[k]           = ((uint32_t) (sum & 0xffffffffull));
+        uint64_t const sum = ((uint64_t) z.v[k]) + carry;
+        z.v[k]             = ((uint32_t) (sum & 0xffffffffull));
         carry              = sum >> 32;
       }
     }
@@ -242,22 +245,22 @@ static kit_bigint_t kit_bi_mul_uint32(kit_bigint_t const x,
 static kit_bigint_t kit_bi_mul(kit_bigint_t const x,
                                kit_bigint_t const y) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z);
 
   for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) {
-    if (x.v32[i] == 0)
+    if (x.v[i] == 0)
       continue;
 
     for (ptrdiff_t j = 0; i + j < KIT_BIGINT_SIZE / 4; j++) {
-      if (y.v32[j] == 0)
+      if (y.v[j] == 0)
         continue;
 
-      uint64_t carry = ((uint64_t) x.v32[i]) * ((uint64_t) y.v32[j]);
+      uint64_t carry = ((uint64_t) x.v[i]) * ((uint64_t) y.v[j]);
 
       for (ptrdiff_t k = i + j; k < KIT_BIGINT_SIZE / 4 && carry != 0;
            k++) {
-        uint64_t const sum = ((uint64_t) z.v32[k]) + carry;
-        z.v32[k]           = ((uint32_t) (sum & 0xffffffffull));
+        uint64_t const sum = ((uint64_t) z.v[k]) + carry;
+        z.v[k]             = ((uint32_t) (sum & 0xffffffffull));
         carry              = sum >> 32;
       }
     }
@@ -288,35 +291,55 @@ static kit_bi_division_t kit_bi_div(kit_bigint_t const x,
   ptrdiff_t       shift  = x_bits - y_bits;
 
   result.remainder = x;
-  result.quotient  = kit_bi_uword(0);
+  result.quotient  = kit_bi_uint32(0);
 
-  y = kit_bi_shl_uword(y, (kit_uword_t) shift);
+  y = kit_bi_shl_uint(y, (uint32_t) shift);
 
   while (shift >= 0) {
     if (kit_bi_compare(result.remainder, y) >= 0) {
       result.remainder = kit_bi_sub(result.remainder, y);
-      result.quotient.v8[shift / 8] |= (1u << (shift % 8));
+      result.quotient.v[shift / 32] |= (1u << (shift % 32));
     }
 
-    y = kit_bi_shr_uword(y, 1);
+    y = kit_bi_shr_uint(y, 1);
     shift--;
   }
 
   return result;
 }
 
+static void kit_bi_serialize(kit_bigint_t const in,
+                             uint8_t *const     out) { /* writes KIT_BIGINT_SIZE bytes to out */
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE / 4; i++) { /* one 32-bit word -> 4 bytes */
+    out[i * 4]     = (uint8_t) (in.v[i] & 0xff); /* bits 0..7, lowest byte first */
+    out[i * 4 + 1] = (uint8_t) ((in.v[i] >> 8) & 0xff); /* bits 8..15 */
+    out[i * 4 + 2] = (uint8_t) ((in.v[i] >> 16) & 0xff); /* bits 16..23 */
+    out[i * 4 + 3] = (uint8_t) ((in.v[i] >> 24) & 0xff); /* bits 24..31 */
+  }
+}
+
+static kit_bigint_t kit_bi_deserialize(uint8_t const *const in) { /* reads KIT_BIGINT_SIZE bytes from in */
+  kit_bigint_t out;
+  memset(&out, 0, sizeof out); /* words start at zero; bytes are OR-ed in below */
+
+  for (ptrdiff_t i = 0; i < KIT_BIGINT_SIZE; i++)
+    out.v[i / 4] |= ((uint32_t) in[i]) << (8 * (i % 4)); /* byte i -> word i / 4, bit offset 8 * (i % 4) */
+
+  return out;
+}
+
 static uint8_t kit_bin_digit(char const hex) {
   return hex == '1' ? 1 : 0;
 }
 
 static kit_bigint_t kit_bi_bin(kit_str_t const bin) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z); /* start from zero; bits are OR-ed in below */
 
   for (ptrdiff_t i = 0; i < bin.size && i / 8 < KIT_BIGINT_SIZE;
        i++) {
     uint8_t const digit = kit_bin_digit(bin.values[bin.size - i - 1]);
-    z.v8[i / 8] |= digit << (i % 8);
+    z.v[i / 32] |= ((uint32_t) digit) << (i % 32); /* shift as uint32_t: int << 31 would overflow */
   }
 
   return z;
 }
@@ -327,8 +350,8 @@ static uint8_t kit_dec_digit(char const c) {
 }
 
 static kit_bigint_t kit_bi_dec(kit_str_t const dec) {
-  kit_bigint_t z      = kit_bi_uword(0);
-  kit_bigint_t factor = kit_bi_uword(1);
+  kit_bigint_t z      = kit_bi_uint32(0);
+  kit_bigint_t factor = kit_bi_uint32(1);
 
   for (ptrdiff_t i = 0; i < dec.size; i++) {
     uint32_t const digit = kit_dec_digit(
@@ -352,12 +375,12 @@ static uint8_t kit_hex_digit(char const hex) {
 
 static kit_bigint_t kit_bi_hex(kit_str_t const hex) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z); /* start from zero; digits are OR-ed in below */
 
   for (ptrdiff_t i = 0; i < hex.size && i / 2 < KIT_BIGINT_SIZE;
        i++) {
     uint8_t const digit = kit_hex_digit(hex.values[hex.size - i - 1]);
-    z.v8[i / 2] |= ((i % 2) == 0) ? digit : (digit << 4);
+    z.v[i / 8] |= ((uint32_t) digit) << (4 * (i % 8)); /* 4 bits per digit, 8 digits per word */
   }
 
   return z;
 }
@@ -381,11 +404,11 @@ static uint8_t kit_base32_digit(char const c) {
 
 static kit_bigint_t kit_bi_base32(kit_str_t const base32) {
   kit_bigint_t z;
-  memset(z.v8, 0, KIT_BIGINT_SIZE);
+  memset(&z, 0, sizeof z); /* start from zero */
 
   for (ptrdiff_t i = 0; i < base32.size; i++) {
-    z = kit_bi_shl_uword(z, 5 * i);
-    z.v8[i] |= kit_base32_digit(base32.values[i]);
+    z = kit_bi_shl_uint(z, 5 * i); /* NOTE(review): shift grows with i; a positional decode would shift by 5 -- confirm */
+    z.v[0] |= kit_base32_digit(base32.values[i]); /* OR the digit into word 0 */
   }
 
   return z;
 }
@@ -413,8 +436,8 @@ static uint8_t kit_base58_digit(char const c) {
 }
 
 static kit_bigint_t kit_bi_base58(kit_str_t const base58) {
-  kit_bigint_t z      = kit_bi_uword(0);
-  kit_bigint_t factor = kit_bi_uword(1);
+  kit_bigint_t z      = kit_bi_uint32(0);
+  kit_bigint_t factor = kit_bi_uint32(1);
 
   for (ptrdiff_t i = 0; i < base58.size; i++) {
     uint32_t const digit = kit_base58_digit(
@@ -448,7 +471,7 @@ static kit_bigint_t kit_bi_base58(kit_str_t const base58) {
 
 #ifndef KIT_DISABLE_SHORT_NAMES
 #  define bigint_t kit_bigint_t
-#  define bi_uword kit_bi_uword
+#  define bi_uint32 kit_bi_uint32
 #  define bi_uint64 kit_bi_uint64
 #  define bi_equal kit_bi_equal
 #  define bi_carry kit_bi_carry
@@ -457,6 +480,8 @@ static kit_bigint_t kit_bi_base58(kit_str_t const base58) {
 #  define bi_sub kit_bi_sub
 #  define bi_mul kit_bi_mul
 #  define bi_div kit_bi_div
+#  define bi_serialize kit_bi_serialize
+#  define bi_deserialize kit_bi_deserialize
 #  define hex_digit kit_hex_digit
 #  define bi_hex kit_bi_hex
 #  define bi_base58 kit_bi_base58
diff --git a/source/test/unittests/bigint.test.c b/source/test/unittests/bigint.test.c
index 66cd87a..5ca49b5 100644
--- a/source/test/unittests/bigint.test.c
+++ b/source/test/unittests/bigint.test.c
@@ -21,7 +21,7 @@ TEST("bigint hex sub") {
 }
 
 TEST("bigint base58") {
-  REQUIRE(bi_equal(BASE58("31"), bi_uword(58 * 2)));
+  REQUIRE(bi_equal(BASE58("31"), bi_uint32(58 * 2)));
 }
 
 TEST("bigint base58 add") {
-- 
cgit v1.2.3