Googling the name of the challenge points to a paper (IACR ePrint 2018/749) detailing how Miller-Rabin tests are misused in practice: when the set of bases is fixed, specially constructed composites (strong pseudoprimes to every one of those bases) still pass the Miller-Rabin test.

Searching for an implementation of the paper's algorithm leads to a GitHub repository (the link was not preserved here). Running its script produces a pseudoprime $n = p_1 p_2 p_3$. Sending $n$ as the `prime` and one of its prime factors (e.g. $p_1$) as the `base` returns the flag.

I also wrote my own implementation, which I believe is correct, but I had trouble picking residues that combine under the CRT into a valid starting value for $p_1$.

Sage Implementation:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from Crypto.Util.number import getPrime
from tqdm import tqdm
import itertools
# Algorithm from https://eprint.iacr.org/2018/749.pdf
# Check out the example below for illustration of how the algorithm works

def generate_basis(n):
    """
    Return every prime strictly below ``n``, in increasing order.

    Sieve of Eratosthenes over the odd numbers; 2 is prepended separately.
    For ``n <= 2`` there is no such prime, so an empty list is returned.
    """
    import math

    if n <= 2:
        return []
    sieve = [True] * n
    # Only factors up to isqrt(n) can have an unmarked multiple below n.
    for i in range(3, math.isqrt(n) + 1, 2):
        if sieve[i]:
            # Strike the odd multiples of i from i*i on; even indices are never read.
            sieve[i * i::2 * i] = [False] * len(range(i * i, n, 2 * i))
    return [2] + [i for i in range(3, n, 2) if sieve[i]]

def miller_rabin(n, b):
    """
    Strong-probable-prime (Miller-Rabin) test of ``n`` against every prime
    base smaller than ``b``.

    Returns True when ``n`` is prime or is a strong pseudoprime to all of
    those bases. Returns False when some base witnesses that ``n`` is
    composite. Integers below 2 are reported as not prime.
    """
    # n == 1 would otherwise loop forever below (s = 0 is never odd).
    if n < 2:
        return False
    if n == 2 or n == 3:
        return True
    if n % 2 == 0:
        return False

    witnesses = generate_basis(b)

    # Write n - 1 = 2**r * s with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2

    for base in witnesses:
        # A base divisible by n gives base**s == 0 (mod n) and proves nothing.
        # The bases are primes, so this only happens when n is one of them.
        if base % n == 0:
            continue
        x = pow(base, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # No square reached n - 1: base is a witness of compositeness.
            return False
    return True

def crt_backtrack(options, mods, erem, emod):
    """
    Depth-first search for one CRT-compatible choice of residues.

    Not every combination of residues has a CRT solution. This function
    takes each candidate residue in ``options[0]`` (taken modulo ``mods[0]``),
    merges it with the fixed condition ``x % emod == erem``, and recurses on
    the remaining options and moduli.

    Returns ``(residue, modulus)`` for the first complete, consistent choice.
    Returns None when no choice works. A ValueError from ``crt`` (incompatible
    congruences) marks a dead end. ``crt`` and ``lcm`` are expected to be the
    Sage globals.
    """
    if not options:
        return (erem, emod)

    for candidate in options[0]:
        try:
            merged = crt([erem, candidate], [emod, mods[0]])
            outcome = crt_backtrack(
                options[1:], mods[1:], merged, lcm([mods[0], emod])
            )
        except ValueError:
            continue
        if outcome:
            return outcome
    return None

def legendre(a, p):
    """
    Evaluate Euler's criterion ``a**((p - 1) // 2) mod p``.

    For an odd prime ``p`` this is the Legendre symbol (a/p) read modulo p:
    1 for a non-zero quadratic residue, 0 when p divides a, and ``p - 1``
    (not -1) for a quadratic non-residue.
    """
    half_exponent = (p - 1) // 2
    symbol_mod_p = pow(a, half_exponent, p)
    return symbol_mod_p

# Prime bases (denoted a in the paper) that n must fool: every prime below `bases`.
bases = 64
prime_bases = generate_basis(bases)

# S_b[a] is the paper's S_a: residues mod 4a of the odd primes p with
# Legendre symbol (a/p) = -1.
S_b = {}
potential_primes = generate_basis(50000)

for a in prime_bases:
    # potential_primes[0] == 2 is skipped: modulo 2 the values 1 and -1
    # coincide. legendre() reports -1 as p - 1.
    S_b[a] = {
        p % (4 * a) for p in potential_primes[1:] if legendre(a, p) == p - 1
    }

# Multipliers for Arnault's construction p_i = k_i * (p1 - 1) + 1 (k1 = 1).
# k2 and k3 must differ: otherwise the two extra congruences below contradict
# each other and -k2 is not invertible mod k3. Not every pair admits a
# solution; rerun the script to draw new ones.
k2 = getPrime(10)
k3 = getPrime(10)
while k3 == k2:
    k3 = getPrime(10)

k_values = [1, k2, k3]

# For every base a, the admissible residues of p1 modulo 4a are the
# intersection over k of  k^{-1} * (S_a + k - 1)  (mod 4a).
moduli = []
residues = []
for a in prime_bases:
    modulus = 4 * a
    k_sets = []
    for k in k_values:
        inv_k = pow(k, -1, modulus)  # does not depend on s: compute once per k
        k_sets.append({(inv_k * (s + k - 1)) % modulus for s in S_b[a]})
    residues_set = k_sets[0].intersection(*k_sets[1:])
    if not residues_set:
        # The product below will then be empty; nothing can be found.
        print(f"No admissible residue for p1 mod {modulus} with k = {k_values}; rerun")
    residues.append(residues_set)
    moduli.append(modulus)

found = False
for pos_residues in itertools.product(*residues):
    # Arnault's extra conditions (the paper's statement is missing the minus
    # signs; compare its worked example):
    #     p1 = -k3^{-1} (mod k2)   and   p1 = -k2^{-1} (mod k3)
    # With p_i = k_i * (p1 - 1) + 1 they give p1 * p3 = 1 (mod k2) and
    # p1 * p2 = 1 (mod k3). Each residue must sit at the same index as its
    # own modulus.
    temp_residues = list(pos_residues) + [pow(-k3, -1, k2), pow(-k2, -1, k3)]
    temp_moduli = moduli + [k2, k3]

    # Sage's crt(values, moduli) returns only the solution, not a
    # (solution, modulus) pair. It raises ValueError when the congruences are
    # incompatible, which can happen because the 4a moduli are not pairwise
    # coprime.
    try:
        p1 = crt(temp_residues, temp_moduli)
    except ValueError:
        continue
    mod = lcm(temp_moduli)

    # Jump far ahead instead of stepping from a tiny p1: the target n is
    # ~600 bits, so each prime factor should exceed ~200 bits.
    p1 += (2 ** 73) * mod

    for _ in tqdm(range(100000)):
        # Arnault's equations for the other primes: p_i = k_i * (p1 - 1) + 1.
        p2 = k2 * (p1 - 1) + 1
        p3 = k3 * (p1 - 1) + 1

        if p1.is_prime() and p2.is_prime() and p3.is_prime():
            n = p1 * p2 * p3
            passes = miller_rabin(n, bases)
            print(n.bit_length())
            print(passes)
            if passes:
                found = True
                break

        # Adding the full modulus keeps p1 in the same residue class for every constraint.
        p1 += mod

    if found:
        print(p1, p2, p3)
        break

if not found:
    print("No strong pseudoprime found; rerun to draw new k2, k3")
# print(p1, p2, p3)
# io = remote('socket.cryptohack.org', 13385)
# to_send = dict()
# to_send['prime'] = n 
# to_send['base'] = p1

# io.sendline(json.dumps(to_send).encode())
# io.interactive()