This is a challenge about the length extension attack on hashes such as MD5 and SHA-1, which use the Merkle–Damgård construction. More on that can be found in this link, and if you understand Vietnamese, a good YouTube series on this topic is the one by CyberJutsu.

The following assumes you already have some experience with the attack. In a typical MD5 length extension attack, we can extend the secret with arbitrary data of our choosing. Here, however, the appended data first goes through the bxor function, which XORs it with bytes we know nothing about (the flag). Unless we craft the payload we send to the server carefully, we do not know the message that actually gets hashed after the XOR.
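To make the setting concrete, here is a hypothetical model of the server. The writeup only tells us that the data is passed through bxor(data, cycle(FLAG)) before hashing. The function bodies, the `server_hash` name and the stand-in flag below are assumptions made for illustration, not the challenge's actual source.

```python
import hashlib
from itertools import cycle

# Hypothetical stand-in for the real flag: 46 bytes, same shape as the dummy flag in the solution
FLAG = b"crypto{" + b"A" * 38 + b"}"


def bxor(a: bytes, b) -> bytes:
    """XOR the bytes of `a` with the iterable `b`; zip stops after len(a) bytes."""
    return bytes(x ^ y for x, y in zip(a, b))


def server_hash(data: bytes) -> str:
    """Assumed server behaviour: MD5 of the data XOR-ed with the cycled flag."""
    return hashlib.md5(bxor(data, cycle(FLAG))).hexdigest()


# Sending zero bytes makes the hashed message the flag keystream itself
zeros = b"\x00" * 50
print(bxor(zeros, cycle(FLAG)) == FLAG + FLAG[:4])  # True
```

Every byte we append is XOR-ed with some flag byte. That is why the appended data is under our control only where the keystream is known.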

We do, however, know that the flag starts with crypto{ and ends with }. Because bxor is called as bxor(data, cycle(FLAG)), the flag is repeated over and over as a keystream. Wherever one copy of the flag ends and the next begins, we therefore know 8 consecutive keystream bytes: }crypto{. The idea is to craft a payload so that these 8 known bytes line up with the MD5 padding, whose content we know exactly: it is fixed by the algorithm and by the length of the message passed to MD5.
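The position of the known window can be checked directly. The sketch below assumes a 46-byte flag, the length used throughout the solution. The helper `known_keystream_byte` is a hypothetical name. For each keystream offset it returns the byte we know, or `None` for an unknown flag character.

```python
FLAG_LEN = 46                  # assumed flag length (matches the dummy flag in the solution)
PREFIX, SUFFIX = b"crypto{", b"}"


def known_keystream_byte(pos: int):
    """Keystream byte at offset `pos` of cycle(FLAG) if it is known, else None."""
    i = pos % FLAG_LEN
    if i < len(PREFIX):
        return PREFIX[i]
    if i == FLAG_LEN - 1:
        return SUFFIX[0]
    return None  # one of the unknown flag characters


# The run of known bytes straddles the wrap-around from one flag copy to the next
window = bytes(known_keystream_byte(p) for p in range(FLAG_LEN - 1, FLAG_LEN + 7))
print(window)  # b'}crypto{'
```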

We therefore craft a message whose length in bytes is congruent to 55 (mod 64). Its padding is then exactly \x80 followed by the 8-byte message length, with no zero bytes in between. We also want the next 8 bytes of cycle(FLAG) after the message to be }crypto{. The flag is 46 bytes long, so we take $x$ copies of the flag minus the final }, which gives a payload of $46x - 1$ bytes satisfying $46x - 1 \equiv 55 \pmod{64}$. The smallest solution is $x = 4$, so the payload is $46 \times 4 - 1 = 183$ bytes long.
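The congruence is small enough to solve by brute force. The sketch assumes the 46-byte flag length used above and also checks where the payload ends inside the flag cycle.

```python
FLAG_LEN = 46    # assumed flag length
BLOCK = 64       # MD5 block size in bytes
TARGET = 55      # length 55 mod 64 leaves room for exactly b"\x80" + 8 length bytes

# Smallest number of flag copies x with (46x - 1) mod 64 == 55
x = next(x for x in range(1, BLOCK + 1) if (FLAG_LEN * x - 1) % BLOCK == TARGET)
payload_len = FLAG_LEN * x - 1

print(x, payload_len)          # 4 183
print(payload_len % FLAG_LEN)  # 45, i.e. the next keystream byte is the final "}"
```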

Unfortunately, the padding always consists of \x80 followed by the 8-byte message length, so it is at least 9 bytes long, while we only know 8 keystream bytes. The 9th padding byte (the most significant byte of the length, which is \x00 here) is XOR-ed with the first flag character after {, so we have to guess that character. To check a guess, we extend the hash of the 183-byte message with one extra block that contains only the padding for the resulting 192-byte message. The guess is correct when the hash returned by the server matches this locally computed extension of the previously requested hash.
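The byte layout can be verified offline. This is a minimal sketch, assuming the 183-byte payload from above and the standard MD5 padding rule. `md5_padding`, `ext` and the stand-in flag `secret` are hypothetical names used only for this check. It confirms that XOR-ing the first 8 padding bytes with }crypto{ gives bytes we can send, and that only the correct guess for the 9th byte makes the server hash the message followed by its own padding.

```python
N = 46 * 4 - 1                # 183-byte payload: the \x80 lands at offset 55 mod 64
KNOWN = b"}crypto{"           # keystream bytes at offsets N .. N + 7


def md5_padding(msg_len: int) -> bytes:
    """MD5 padding: 0x80, zeros up to 56 mod 64, then the 64-bit little-endian bit length."""
    return b"\x80" + b"\x00" * ((55 - msg_len) % 64) + (msg_len * 8 % 2**64).to_bytes(8, "little")


pad = md5_padding(N)          # exactly 9 bytes, no zero filler
ext = bytes(p ^ k for p, k in zip(pad[:8], KNOWN))  # the 8 bytes we can compute exactly

# Check against a hypothetical flag whose first unknown character is "x"
secret = b"crypto{" + b"x" * 38 + b"}"
stream = (secret * 5)[: N + len(pad)]   # keystream for offsets 0 .. N + 8
msg = stream[:N]                        # what the server hashes for N zero bytes

matches = []
for g in range(32, 127):                # printable ASCII guesses for secret[7]
    sent = b"\x00" * N + ext + bytes([g])
    hashed = bytes(a ^ b for a, b in zip(sent, stream))
    if hashed == msg + pad:             # the server now hashes msg || padding(msg)
        matches.append(chr(g))

print(pad.hex())  # 80b805000000000000
print(matches)    # ['x']
```

The last length byte is zero, so the correct guess is simply the flag character itself (its XOR with the keystream byte is zero). This is why the solution appends each candidate character unchanged.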

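Once the payload is 64-byte aligned, recovering the remaining characters is straightforward. Each extra \x00 byte we append leaves the next flag byte unchanged in the hashed message. We already know the MD5 state after all the full blocks, so we can compute the hash of the last block locally for every candidate character and compare it with the hash returned by the server.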
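Here is a minimal, self-contained sketch of this step. It uses a compact MD5 written from RFC 1321 with `math.sin` (not the author's numpy-based version below) whose state can be loaded from a digest. The server is simulated with `hashlib`. `prefix` stands in for the unknown 183-byte message, and `secret_tail` stands in for the flag bytes revealed one at a time. Both are hypothetical, as are the helper names `state_from_digest` and `md5_from_state`.

```python
import hashlib
import math
import struct

MASK = 0xFFFFFFFF
# MD5 constants (RFC 1321): per-step rotation amounts and K[i] = floor(2^32 * |sin(i + 1)|)
S = [7, 12, 17, 22] * 4 + [5, 9, 14, 20] * 4 + [4, 11, 16, 23] * 4 + [6, 10, 15, 21] * 4
K = [int(abs(math.sin(i + 1)) * 2**32) & MASK for i in range(64)]
IV = (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476)


def rotl(x: int, n: int) -> int:
    """Rotate a 32-bit word left by n bits."""
    return ((x << n) | (x >> (32 - n))) & MASK


def compress(state: tuple, block: bytes) -> tuple:
    """Apply the MD5 compression function to one 64-byte block."""
    X = struct.unpack("<16I", block)
    a, b, c, d = state
    for i in range(64):
        if i < 16:
            f, g = (b & c) | (~b & d), i
        elif i < 32:
            f, g = (d & b) | (~d & c), (5 * i + 1) % 16
        elif i < 48:
            f, g = b ^ c ^ d, (3 * i + 5) % 16
        else:
            f, g = c ^ (b | (~d & MASK)), (7 * i) % 16
        f = (f + a + K[i] + X[g]) & MASK
        a, d, c = d, c, b
        b = (b + rotl(f, S[i])) & MASK
    return tuple((s + v) & MASK for s, v in zip(state, (a, b, c, d)))


def md5_padding(msg_len: int) -> bytes:
    """0x80, zeros up to 56 mod 64, then the 64-bit little-endian bit length."""
    return b"\x80" + b"\x00" * ((55 - msg_len) % 64) + (msg_len * 8 % 2**64).to_bytes(8, "little")


def state_from_digest(hexdigest: str) -> tuple:
    """An MD5 digest is the final state, serialised as four little-endian words."""
    return struct.unpack("<4I", bytes.fromhex(hexdigest))


def md5_from_state(state: tuple, data: bytes, total_len: int) -> str:
    """Resume MD5 from `state` (which absorbed a 64-byte-aligned prefix) over `data`.
    total_len is the whole message length, prefix included, used for the final padding."""
    buf = data + md5_padding(total_len)
    for i in range(0, len(buf), 64):
        state = compress(state, buf[i:i + 64])
    return struct.pack("<4I", *state).hex()


# Simulated server: `prefix` stands in for the unknown 183-byte keystream (hypothetical)
prefix = bytes(range(183))
glue = md5_padding(len(prefix))           # prefix + glue is exactly 192 bytes (3 blocks)
state = state_from_digest(hashlib.md5(prefix).hexdigest())  # all the attacker knows
secret_tail = b"hunter2"                  # hypothetical flag bytes revealed one at a time

recovered = b""
for k in range(len(secret_tail)):
    # Oracle answer: hash of the aligned prefix followed by one more secret byte
    target = hashlib.md5(prefix + glue + secret_tail[: k + 1]).hexdigest()
    for g in range(32, 127):              # printable ASCII candidates
        guess = recovered + bytes([g])
        if md5_from_state(state, guess, len(prefix) + len(glue) + len(guess)) == target:
            recovered = guess
            break

print(recovered)  # b'hunter2'
```

Here `prefix + glue` plays the role of `dummy_pad` in the solution. Only the length of the aligned prefix matters for the local computation, because its content is already folded into the loaded state.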

For this solution I implemented MD5 myself, which definitely helped in solving the challenge. Other solutions apparently use hashpumpy to keep the code shorter.

Python Implementation:

import numpy as np
from pwn import *
import json 

# Implementation from https://www.youtube.com/watch?v=HWpaz5XsECc
# chosen because it produces the same hashes as Python's hashlib MD5 implementation

# ----- BEGIN IMPLEMENTATION OF MD5 -----
shift = [7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
         5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20,
         4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
         6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21]

sines = np.abs(np.sin(np.arange(64) + 1))  # "nothing up my sleeve" randomness
sine_randomness = [int(x) for x in np.floor(2 ** 32 * sines)]

md5_block_size = 64
md5_digest_size = 16

def left_rotate(x: int, y: int) -> int:
    """
    Rotate the bits of x by y places, as if x and y are 32-bit unsigned integers.
    >>> left_rotate(0b11111111000000001010101011001100, 1) == \
                    0b11111110000000010101010110011001
    True
    """
    return ((x << (y & 31)) | ((x & 0xffffffff) >> (32 - (y & 31)))) & 0xffffffff


def bit_not(x: int) -> int:
    """
    The bitwise complement of x if x were represented as a 32-bit unsigned integer.
    >>> bit_not(0b11111111000000001010101011001100) == \
                0b00000000111111110101010100110011
    True
    """
    return 4294967295 - x


"""
Mixing functions. 
Each of F, G, H, I has the following property.
Given: all the bits of all the inputs are independent and unbiased,
Then: the bits of the output are also independent and unbiased.
"""

def F(b: int, c: int, d: int) -> int:
    return d ^ (b & (c ^ d))

def G(b: int, c: int, d: int) -> int:
    return c ^ (d & (b ^ c))

def H(b: int, c: int, d: int) -> int:
    return b ^ c ^ d

def I(b: int, c: int, d: int) -> int:
    return c ^ (b | bit_not(d))

mixer_for_step = [F for _ in range(16)] + [G for _ in range(16)] + [H for _ in range(16)] + [I for _ in range(16)]

"""
These are all permutations of [0, ..., 15].
"""

round_1_perm = [i for i in range(16)]  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
round_2_perm = [(5 * i + 1) % 16 for i in range(16)]  # [1, 6, 11, 0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12]
round_3_perm = [(3 * i + 5) % 16 for i in range(16)]  # [5, 8, 11, 14, 1, 4, 7, 10, 13, 0, 3, 6, 9, 12, 15, 2]
round_4_perm = [(7 * i) % 16 for i in range(16)]  # [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]

msg_idx_for_step = round_1_perm + round_2_perm + round_3_perm + round_4_perm

class MD5:
    def __init__(self):
        self.length: int = 0
        self.state: tuple[int, int, int, int] = (0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476)
        self.message: bytes = None 
    
    def digest(self) -> bytes:
        return b''.join(x.to_bytes(length=4, byteorder='little') for x in self.state)

    def hex_digest(self) -> str:
        return self.digest().hex()
    
    # Message padding, following the MD5 description on Wikipedia
    # https://en.wikipedia.org/wiki/MD5
    def pad(self, message):
        self.message = message 
        self.length = (len(message) * 8) % (2 ** 64)

        pad = message + b'\x80'
        while len(pad) % 64 != 56:
            pad += b'\x00'
        
        pad += self.length.to_bytes(length=8, byteorder='little')

        return pad 

    # The compression function: takes a 64-byte chunk and updates
    # the current state in place
    def compress(self, msg_chunk: bytearray) -> None:
        assert len(msg_chunk) == md5_block_size  # 64 bytes, 512 bits
        msg_ints = [int.from_bytes(msg_chunk[i:i + 4], byteorder='little') for i in range(0, md5_block_size, 4)]
        assert len(msg_ints) == 16

        a, b, c, d = self.state

        for i in range(md5_block_size):
            bit_mixer = mixer_for_step[i]
            msg_idx = msg_idx_for_step[i]
            a = (a + bit_mixer(b, c, d) + msg_ints[msg_idx] + sine_randomness[i]) % (2 ** 32)
            a = left_rotate(a, shift[i])
            a = (a + b) % (2 ** 32)
            a, b, c, d = d, a, b, c
        
        self.state = (
            (self.state[0] + a) % (2 ** 32),
            (self.state[1] + b) % (2 ** 32),
            (self.state[2] + c) % (2 ** 32),
            (self.state[3] + d) % (2 ** 32),
        )

    # Load the state of the previous blocks from the given hash
    def load_state(self, hash):
        registers = [0, 0, 0, 0]

        for i in range(len(registers)):
            registers[i] = int.from_bytes(bytes.fromhex(hash[8 * i: 8 * (i + 1)]), byteorder='little')
        self.state = tuple(registers)

# ----- END IMPLEMENTATION OF MD5 -----

# MD5 instance used for padding and local hash computations
m = MD5()

# Dummy flag to use for the dummy padding, only purpose is to extract the extension 
# in particular, the \x80 byte and the little endian representation of the length of 
# the message in bits
FLAG = b'crypto{??????????????????????????????????????}'

# All ASCII characters
chars = [chr(i) for i in range(32, 127)]

# Keystream seen by a 183-byte payload: the flag repeated 4 times minus its final '}'
dummy = FLAG * 3 + FLAG[:-1]

# Any bytes of length len(dummy) work; I use \x00
payload = b'\x00' * len(dummy)

# Connect to server
io = remote('socket.cryptohack.org', 13407)
io.recvline()
to_send = {'option': 'message', 'data': payload.hex()}
io.sendline(json.dumps(to_send).encode())

# Hash of the 183 hashed bytes we do not control; it equals the MD5 state after the first 3 blocks
secret_blk = json.loads(io.recvline().decode())['hash']

# Get extension (the padding of \x80, followed by some \x00 and 
# the length of the message) for hash-extension attack
dummy_pad = m.pad(dummy)
extension = dummy_pad[-9:]
payload += xor(extension[:8], b'}' + FLAG[:7])

# To guess the first character after the {, pad the 192-byte message once more:
# its 4th block contains only the padding, which is what the server hashes after a correct guess
dummy_pad_extend = m.pad(dummy_pad)

# Load the state of the secret block to compare
m.load_state(secret_blk)

# The hash is the state update on the extension block
m.compress(dummy_pad_extend[64 * 3:])
target = m.hex_digest()

FLAG = 'crypto{'

# Guess out the character after the { character
for char in chars: 
    guess = payload + char.encode()
    to_send = {'option': 'message', 'data': guess.hex()}
    io.sendline(json.dumps(to_send).encode())
    hash = json.loads(io.recvline().decode())['hash']
    if hash == target:
        payload = guess 
        FLAG += char
        break

print(FLAG)

# We know the MD5 state before the last block and can request the server's hash,
# so we compute the hash locally for each candidate character and keep the one
# that produces the same hash
for i in range(38):
    payload += b'\x00'  # XOR with \x00 keeps the next flag byte unchanged in the hashed message
    to_send = {'option': 'message', 'data': payload.hex()}
    io.sendline(json.dumps(to_send).encode())
    target = json.loads(io.recvline().decode())['hash']

    # Guessing out which character produces the same hash
    for char in chars:  
        m = MD5()
        m.load_state(secret_blk)
        # guessing the characters in the final block
        guess = m.pad(dummy_pad + FLAG[8:].encode() + char.encode()) 

        # As above, the hash is the state after compressing the final block
        m.compress(guess[64 * 3:])
        guess_hash = m.hex_digest()
        if guess_hash == target:
            FLAG += char 
            print(FLAG)
            break