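The challenge uses an LCG to generate the plaintext, the key and the IV. To receive the flag we must submit the AES-CBC encryption of the plaintext under that key and IV, but the server only gives us the plaintext and the IV, not the key. Each output byte is the top 8 bits of the 48-bit LCG state, and after every 16 outputs the generator is reseeded with a fresh random state (the `refresh` function). Observing the code, we can learn that: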

  • The last 8 bytes of the plaintext and the first 8 bytes of the key are derived from the same 16 consecutive states.
  • The last 8 bytes of the key and the first 8 bytes of the IV are derived from the same 16 consecutive states.
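To make the window structure concrete, here is a minimal Python sketch that mirrors the generator from the challenge code, with one assumption: the reseeding source is passed in as a hypothetical `reseed` callable instead of calling `urandom`, so the run is deterministic. It shows that every byte is the top 8 bits of a 48-bit state, and that only outputs inside the same 16-output window are linked by the recurrence.

```python
A = 0x1337DEADBEEF
B = 0xB
M = 2**48


class TruncatedLCG:
    """Mirror of the challenge LCG; `reseed` is a hypothetical stand-in for urandom."""

    def __init__(self, seed, reseed):
        self.state = seed
        self.counter = 0
        self.reseed = reseed
        self.states = []  # every state that produced an output, kept for inspection

    def next_byte(self):
        if self.counter == 16:  # refresh after 16 outputs
            self.counter = 0
            self.state = self.reseed()
        self.counter += 1
        self.state = (A * self.state + B) % M
        self.states.append(self.state)
        return self.state >> 40  # keep only the 8 most significant bits


gen = TruncatedLCG(0x111111111111, reseed=lambda: 0x0123456789AB)
out = bytes(gen.next_byte() for _ in range(32))

# Inside the first window, consecutive states follow s_{i+1} = a*s_i + b (mod m) ...
linked = all(gen.states[i + 1] == (A * gen.states[i] + B) % M for i in range(15))
# ... but the 17th state comes from a fresh seed, not from the 16th state.
broken = gen.states[16] != (A * gen.states[15] + B) % M
print(linked, broken)
```

This is why the attack has to work on 8 bytes at a time: any 8 bytes taken across a refresh boundary would not satisfy a single recurrence.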

We must therefore recover the states behind the last 8 bytes of the plaintext and behind the first 8 bytes of the IV. From those, we can recover the first 8 bytes of the key (by generating the 8 states that follow the last plaintext state) and the last 8 bytes of the key (by generating the 8 states that precede the first IV state).
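Assuming the lattice step has already returned the state behind the last plaintext byte, predicting the first half of the key is just a matter of running the recurrence forward. The sketch below simulates one window from an arbitrary seed (chosen only for illustration): bytes 0 to 7 play the role of `plaintext[8:]`, and bytes 8 to 15 play the role of `key[:8]`.

```python
A = 0x1337DEADBEEF
B = 0xB
M = 2**48


def step(state):
    """One LCG step: s -> a*s + b (mod m)."""
    return (A * state + B) % M


def forward_bytes(state, count):
    """Return the next `count` output bytes after `state` (top 8 bits of each new state)."""
    out = bytearray()
    for _ in range(count):
        state = step(state)
        out.append(state >> 40)
    return bytes(out)


# Simulate one 16-output window from an arbitrary seed.
seed = 0x2B7E151628AE
window_states = []
s = seed
for _ in range(16):
    s = step(s)
    window_states.append(s)
window_bytes = bytes(x >> 40 for x in window_states)

plaintext_tail = window_bytes[:8]  # what the server shows us
key_head = window_bytes[8:]        # what we want to predict

# Suppose the lattice attack recovered the state behind the last plaintext byte.
recovered_last_state = window_states[7]
predicted = forward_bytes(recovered_last_state, 8)
print(predicted == key_head)
```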

This is the attack described in the paper Reconstructing Truncated Integer Variables Satisfying Linear Congruences (Frieze et al.), which goes into detail on the math involved. There are some implementations of this attack online; I used this implementation, which is nice to work with because the author documents every argument.

Another good reference on lattice constructions is this link.
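The lattice in `attack` rests on one invariant that can be checked without any lattice reduction. Subtracting the increment terms $d_0 = b$, $d_i = a d_{i-1} + b$ from the states turns the LCG into a purely multiplicative one, $t_i = s_i - d_i \equiv a^i t_0 \pmod m$. Every row of the initial basis (first row $(m, 0, \dots, 0)$, row $i$ with $a^i$ in column 0 and $-1$ in column $i$) therefore has inner product $\equiv 0 \pmod m$ with $t$, and so does every row of the LLL-reduced basis, since those rows are integer combinations of the originals. The short reduced rows are what let `attack` recover the unknown low bits by rounding. The pure-Python check below (no Sage and no LLL, so it only illustrates the invariant, not the full attack) verifies this on simulated states.

```python
import random

A = 0x1337DEADBEEF
B = 0xB
M = 2**48
N = 8  # number of consecutive outputs the attack uses


def simulate_states(seed, n):
    states, s = [], seed
    for _ in range(n):
        s = (A * s + B) % M
        states.append(s)
    return states


def increment_offsets(n):
    """d_0 = b, d_i = a*d_{i-1} + b (mod m), as the `delta` loop in `attack` computes."""
    offsets, d = [], B % M
    for _ in range(n):
        offsets.append(d)
        d = (A * d + B) % M
    return offsets


def basis(n):
    """The initial basis from `attack`, before LLL."""
    rows = [[0] * n for _ in range(n)]
    rows[0][0] = M
    for i in range(1, n):
        rows[i][0] = A ** i
        rows[i][i] = -1
    return rows


rng = random.Random(1337)
states = simulate_states(rng.getrandbits(48), N)
t = [(s - d) % M for s, d in zip(states, increment_offsets(N))]

# Each basis row is orthogonal to t modulo m.
dots = [sum(r * x for r, x in zip(row, t)) % M for row in basis(N)]
print(dots)
```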

With these states recovered, generating the key is straightforward. For the first 8 bytes, take the last state behind the plaintext and output the next 8 bytes, exactly as the PRNG would. The last 8 bytes of the key are a bit trickier, but still straightforward. We know the first state used for generating the IV, denoted by $s_i$. Each state is related to the one before it by:

$$ s_i \equiv a \cdot s_{i-1} + b \pmod{m} $$

Since $a$ is odd, it is invertible modulo $m = 2^{48}$, so this is equivalent to:

$$ s_{i-1} \equiv (s_i - b) \cdot a^{-1} \pmod{m} $$

Applying this 8 times gives the 8 preceding states, whose top bytes (reversed back into forward order) are the last 8 bytes of the key.
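A minimal sketch of the backward step in plain Python 3.8+, where `pow(a, -1, m)` computes the modular inverse. The seed is arbitrary and only simulates a window whose first 8 outputs play the role of `key[8:]` and whose next 8 outputs play the role of `iv[:8]`.

```python
A = 0x1337DEADBEEF
B = 0xB
M = 2**48
A_INV = pow(A, -1, M)  # exists because A is odd, i.e. gcd(A, 2**48) = 1


def step(state):
    return (A * state + B) % M


def step_back(state):
    """Invert one LCG step: s_{i-1} = (s_i - b) * a^{-1} (mod m)."""
    return ((state - B) * A_INV) % M


def backward_bytes(state, count):
    """Top bytes of the `count` states preceding `state`, returned in forward order."""
    out = bytearray()
    for _ in range(count):
        state = step_back(state)
        out.append(state >> 40)
    return bytes(reversed(out))  # we walked backwards, so reverse


# Simulate one window: 8 key bytes followed by 8 IV bytes.
s = 0x3243F6A8885A
window_states = []
for _ in range(16):
    s = step(s)
    window_states.append(s)
key_tail = bytes(x >> 40 for x in window_states[:8])

# Suppose the lattice attack recovered the state behind the first IV byte.
first_iv_state = window_states[8]
print(backward_bytes(first_iv_state, 8) == key_tail)
```

The reversal at the end plays the same role as `temp[::-1]` in the script: the bytes are collected from the newest state to the oldest, but the key stores them oldest first.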

Sage Implementation:

```python
from Crypto.Util.number import bytes_to_long
from pwn import *
import json
from Crypto.Cipher import AES
from os import urandom

# Run this script with Sage: vector, matrix, ZZ and QQ come from Sage's namespace.

class LCG:
    def __init__(self, a, b, m, seed):
        self.a = a
        self.b = b
        self.m = m
        self.state = seed
        self.counter = 0

    def refresh(self):
        # Reseed with a fresh random 48-bit state
        self.counter = 0
        self.state = bytes_to_long(urandom(6))

    def next_state(self):
        self.state = (self.a * self.state + self.b) % self.m

    def get_random_bits(self, k):
        # After 16 outputs the generator is reseeded
        if self.counter == 16:
            self.refresh()
        self.counter += 1
        self.next_state()
        # Output the k most significant bits of the 48-bit state
        return self.state >> (48 - k)

    def get_random_bytes(self, number):
        bytes_sequence = b''
        for i in range(number):
            bytes_sequence += bytes([self.get_random_bits(8)])
        return bytes_sequence

a = 0x1337deadbeef
b = 0xb
m = 2**48

io = remote('socket.cryptohack.org', int(13396))
io.recvline()

def attack(y, k, s, m, a, c):
    """
    Recovers the states associated with the outputs from a truncated linear congruential generator.
    More information: Frieze, A. et al., "Reconstructing Truncated Integer Variables Satisfying Linear Congruences"
    :param y: the sequential output values obtained from the truncated LCG (the states truncated to s most significant bits)
    :param k: the bit length of the states
    :param s: the bit length of the outputs
    :param m: the modulus of the LCG
    :param a: the multiplier of the LCG
    :param c: the increment of the LCG
    :return: a list containing the states associated with the provided outputs
    """
    diff_bit_length = k - s

    # Preparing for the lattice reduction.
    delta = c % m
    y = vector(ZZ, y)
    for i in range(len(y)):
        # Shift output value to the MSBs and remove the increment.
        y[i] = (y[i] << diff_bit_length) - delta
        delta = (a * delta + c) % m

    # This lattice only works for increment = 0.
    B = matrix(ZZ, len(y), len(y))
    B[0, 0] = m
    for i in range(1, len(y)):
        B[i, 0] = a ** i
        B[i, i] = -1

    B = B.LLL()

    # Finding the target value to solve the equation for the states.
    b = B * y
    for i in range(len(b)):
        b[i] = round(QQ(b[i]) / m) * m - b[i]

    # Recovering the states
    delta = c % m
    x = list(B.solve_right(b))
    for i, state in enumerate(x):
        # Adding the MSBs and the increment back again.
        x[i] = int(y[i] + state + delta)
        delta = (a * delta + c) % m

    return x

to_send = {'option': 'get_a_challenge'}
io.sendline(json.dumps(to_send).encode())
response = json.loads(io.recvline().decode())

plaintext = bytes.fromhex(response['plaintext'])
iv = bytes.fromhex(response['IV'])

# PRNG outputs that share a refresh window with the key:
# the last 8 bytes of the plaintext and the first 8 bytes of the IV
rng_plaintext = list(plaintext[8:])
rng_iv = list(iv[:8])
key = b''

# Recover the 8 states behind the last 8 bytes of the plaintext
rng_plaintext_states = attack(rng_plaintext, 48, 8, m, a, b)
lcg = LCG(a, b, m, rng_plaintext_states[-1])

# The first 8 bytes of the key come from the next 8 states in the same window,
# so step forward from the last plaintext state
for i in range(8):
    key += bytes([lcg.get_random_bits(8)])

# Recover the 8 states behind the first 8 bytes of the IV.
# The last 8 bytes of the key come from the 8 states just before them.
rng_iv_states = attack(rng_iv, 48, 8, m, a, b)
state = rng_iv_states[0]
temp = b''
# Step backwards from the first IV state s: the previous state x satisfies
# a*x + b = s (mod m), thus x = (s - b) * a^-1 (mod m)
a_inv = pow(a, -1, m)
for i in range(8):
    prev_state = ((state - b) * a_inv) % m
    temp += bytes([prev_state >> 40])  # top 8 bits, as in get_random_bits(8)
    state = prev_state

# temp was built newest-to-oldest, so reverse it before appending to the key
key = key + temp[::-1]

cipher = AES.new(key, AES.MODE_CBC, iv)
ciphertext = cipher.encrypt(plaintext)

to_send = {'option': 'validate', 'ciphertext': ciphertext.hex()}
io.sendline(json.dumps(to_send).encode())
response = json.loads(io.recvline().decode())
print(response['msg'])
```
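Since the lattice step is heuristic, it can be worth checking the recovered states against the observed bytes before submitting. This is an optional addition, not part of the script above. The hypothetical helper `states_match_outputs` checks that consecutive states obey the recurrence and that each state's top 8 bits match the byte the server sent. It is plain Python, so it could be called as `states_match_outputs(rng_plaintext_states, rng_plaintext)` on the lists from the script.

```python
A = 0x1337DEADBEEF
B = 0xB
M = 2**48


def states_match_outputs(states, outputs):
    """Hypothetical sanity check for the states returned by `attack`."""
    if len(states) != len(outputs):
        return False
    for i, (state, byte) in enumerate(zip(states, outputs)):
        if not 0 <= state < M or state >> 40 != byte:
            return False
        if i > 0 and state != (A * states[i - 1] + B) % M:
            return False
    return True


# Usage with simulated data: genuine states pass, a perturbed state does not.
s, states = 0x0BADC0FFEE00, []
for _ in range(8):
    s = (A * s + B) % M
    states.append(s)
outputs = [x >> 40 for x in states]

print(states_match_outputs(states, outputs))
# Flipping the lowest bit keeps the top byte but breaks the recurrence.
tampered = states[:3] + [states[3] ^ 1] + states[4:]
print(states_match_outputs(tampered, outputs))
```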