// // By Bob Jenkins, July 2021 // This finds the smallest primitive polynomial for GF(2^n), n <= 64 // g++ -O3 -mpclmul -msse4.1 findprimitive.cpp -o findprimitive // #include #include #include #include #include #include #include #include #define zXor(a,b) _mm_xor_si128(a,b) #define zMult(a,b) _mm_clmulepi64_si128(a,b,0) #define zRight(a,s) _mm_srli_si128(a,s) #define zLeft(a,s) _mm_slli_si128(a,s) static void LogAssert(int condition, const char *fmt, ...) { if (!condition) { printf("error: "); va_list args; va_start(args, fmt); vprintf(fmt, args); va_end(args); printf("\n"); exit(1); } } void Show128(const char* text, __m128i x) { printf("%s : %x %x %x %x\n", text, ((uint32_t *)&x)[3], ((uint32_t *)&x)[2], ((uint32_t *)&x)[1], ((uint32_t *)&x)[0]); } class GF { public: uint64_t m_n; // GF(2^^n) uint64_t m_maxValue; // (2^^n)-1 uint64_t m_nFactors; uint64_t m_factors[64]; // factors of m_maxValues __m128i m_poly; // a (hopefully primitive) polynomial in 2^^m_n __m128i m_invPoly; // 2^^(2*m_n) / m_poly uint64_t m_pow[64]; // 2^^(2^^i) mod m_poly GF(uint64_t n) { m_n = n; m_maxValue = ~0UL; if (n < 64) m_maxValue = (1UL << n) - 1; Factor(); } // // Find the factors of m_maxValue // It may be that 2^^m_maxValue == 1, but it is no good if it is true for factors of m_maxValue too. // Testing 2^^(m_maxValue/factor) for all factors will guard against that. // void Factor() { uint64_t x = m_maxValue; m_nFactors = 0; for (uint64_t i = 3; i*i <= x; i += 2) { if (x == (x / i) * i) { m_factors[m_nFactors++] = i; while (x == (x / i) * i) x /= i; } } if (x > 1) m_factors[m_nFactors++] = x; /* printf("prime factors of (2^%lu)-1 : ", n); for (uint64_t i = 0; i < m_nFactors; i++) printf("%ld ", m_factors[i]); printf("\n"); */ } // Shift bits right. The input it two __m128i, but the output is assumed to be one __m128i. __m128i zShiftBits(__m128i high, __m128i low, uint64_t shift) { // shift the bytes uint64_t byteshift = shift / 8; uint64_t bitshift = shift - (byteshift * 8); if (byteshift >= 16) { low = high; high = _mm_set_epi32(0, 0, 0, 0); byteshift -= 16; } switch (byteshift) { case 0: break; case 1: low = zXor(zRight(low, 1), zLeft(high, 15)); high = zRight(high, 1); break; case 2: low = zXor(zRight(low, 2), zLeft(high, 14)); high = zRight(high, 2); break; case 3: low = zXor(zRight(low, 3), zLeft(high, 13)); high = zRight(high, 3); break; case 4: low = zXor(zRight(low, 4), zLeft(high, 12)); high = zRight(high, 4); break; case 5: low = zXor(zRight(low, 5), zLeft(high, 11)); high = zRight(high, 5); break; case 6: low = zXor(zRight(low, 6), zLeft(high, 10)); high = zRight(high, 6); break; case 7: low = zXor(zRight(low, 7), zLeft(high, 9)); high = zRight(high, 7); break; case 8: low = zXor(zRight(low, 8), zLeft(high, 8)); high = zRight(high, 8); break; case 9: low = zXor(zRight(low, 9), zLeft(high, 7)); high = zRight(high, 9); break; case 10: low = zXor(zRight(low, 10), zLeft(high, 6)); high = zRight(high, 10); break; case 11: low = zXor(zRight(low, 11), zLeft(high, 5)); high = zRight(high, 11); break; case 12: low = zXor(zRight(low, 12), zLeft(high, 4)); high = zRight(high, 12); break; case 13: low = zXor(zRight(low, 13), zLeft(high, 3)); high = zRight(high, 13); break; case 14: low = zXor(zRight(low, 14), zLeft(high, 2)); high = zRight(high, 14); break; case 15: low = zXor(zRight(low, 15), zLeft(high, 1)); high = zRight(high, 15); break; default: LogAssert(false, "unexpected byteshift %lu", byteshift); break; } // shift the bits uint64_t t2 = ((uint64_t *)&high)[0]; uint64_t t1 = ((uint64_t *)&low)[1]; uint64_t t0 = ((uint64_t *)&low)[0]; if (bitshift > 0) { t0 = (t0 >> bitshift) | (t1 << (64 - bitshift)); t1 = (t1 >> bitshift) | (t2 << (64 - bitshift)); } low = _mm_set_epi32((uint32_t)(t1>>32),(uint32_t)t1,(uint32_t)(t0>>32),(uint32_t)t0); return low; } uint64_t Mult(uint64_t x, uint64_t y) { // calculate the full z=x*y, not mod anything // printf("x=%lx, y=%lx\n", x, y); __m128i xx = _mm_set_epi32(0,0,(uint32_t)(x>>32),(uint32_t)x); // Show128("xx", xx); __m128i yy = _mm_set_epi32(0,0,(uint32_t)(y>>32),(uint32_t)y); // Show128("yy", yy); __m128i zz = zMult(xx,yy); // Show128("zz", zz); // Poly is 2^^n ^ (something small). // We want z mod poly. So we want r in (q,r) where q*poly ^ r = z where r < 2^^n. // We precalculated invPoly = 2^^2n/poly, which isn't exact, there was a remainder. // Typically poly is 2^^n ^ (some small stuff) and invPoly is 2^^n ^ (other small stuff). // (invPoly*poly) == 2^^(2n) + stuff, where stuff < 2^^n // (invPoly*poly)>>2n == 1, since stuff>>2n == 0 // (z*invPoly*poly)>>2n == z // (z*invPoly*poly)>>2n == q*poly ^ r, where r < 2^^n // (z*invPoly)>>2n == q (Since poly > r, r/poly == 0.) // ((z*invPoly)>>2n)*poly == q*poly // z ^ ((z*invPoly)>>2n)*poly == r, which is z mod poly, since z ^ q*poly = r. // Since z*invPoly < 2^^3n, this only works up to 42 bits, since 3*43 > 128 bits. __m128i temp0 = _mm_clmulepi64_si128(zz,m_invPoly,0x00); __m128i temp1 = _mm_clmulepi64_si128(zz,m_invPoly,0x10); __m128i temp2 = _mm_clmulepi64_si128(zz,m_invPoly,0x01); __m128i temp3 = _mm_clmulepi64_si128(zz,m_invPoly,0x11); __m128i zi1 = temp3; ((uint64_t *)&zi1)[0] ^= ((uint64_t *)&temp1)[1] ^ ((uint64_t *)&temp2)[1]; __m128i zi0 = temp0; ((uint64_t *)&zi0)[1] ^= ((uint64_t *)&temp1)[0] ^ ((uint64_t *)&temp2)[0]; // Show128("invPoly", m_invPoly); // Show128("zi0", zi0); // Show128("zi1", zi1); __m128i q = zShiftBits(zi1, zi0, 2 * m_n); // Show128("q ", q); __m128i qp = zMult(q, m_poly); // Show128("qp", qp); uint64_t t0 = ((uint64_t *)&qp)[0]; // printf("t0 : %lx\n", t0); t0 ^= ((uint64_t *)&zz)[0]; // printf("r : %lx\n", t0); return t0; } // Fill m_poly with the polynomial and m_invPoly with (2<<(2*m_n))/m_poly. // Also fill m_pow[i] with 2^^i mod m_poly. void UseP(uint64_t p) { // The polynomial is (1<> (64-m_n)); uint64_t t0 = ((m_n == 64) ? 0 : (p << m_n)); // inv is (2<<(2m_n)/m_poly) ^ 2<= m_n; i--) { if (i >= 64) { if (t1 & (1UL << (i-64))) { inv ^= 1UL << (i-m_n); t1 ^= 1UL << (i-64); t1 ^= ((64+m_n-i < 64) ? (p >> (64+m_n-i)) : 0); } } else { if (t0 & (1UL << i)) { inv ^= 1UL << (i-m_n); t0 ^= 1UL << i; t0 ^= p << (i-m_n); } } } LogAssert(t1 == 0, "t1 should be cleared by now"); if (m_n < 64) inv ^= 1UL << m_n; m_invPoly = _mm_set_epi32(0, (m_n==64) ? 1 : 0, (inv>>32), (uint32_t)inv); // 2^^(2*n) / sc_polynomial // Sanity test: // m_inv * m_poly = 2^^(2*n) ^ delta, where delta < 2^^n // So m_inv * m_poly >> n = 2^^n __m128i pi = zMult(m_poly, m_invPoly); // Show128("poly", m_poly); // Show128("invp", m_invPoly); // Show128(" p*i", pi); t1 = ((uint64_t *)&pi)[1]; t0 = ((uint64_t *)&pi)[0]; t0 = (m_n == 64 ? 0 : (t0 >> m_n)); t0 ^= (t1 << (64 - m_n)); // printf("t0 :: %lx\n", t0); LogAssert(t0 == ((m_n == 64) ? 0 : 1UL<