1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
from pwn import *
import json
from Crypto.Util.number import getPrime
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
import hashlib
def B_smooth(total_size, small_factors_size, big_factor_size):
    """Build a prime ``p`` such that ``p - 1`` is a product of chosen primes.

    ``p - 1`` is assembled as ``2 * big * s_1 * ... * s_k``: one "big" prime
    sized by ``big_factor_size`` and several "small" primes sized by
    ``small_factors_size``. Factors are drawn independently at random, so a
    repeated factor is possible (though unlikely).

    Relies on ``random_prime`` and ``is_prime`` being in scope; they are not
    imported in this file (presumably provided by a Sage environment -- TODO
    confirm).

    Args:
        total_size: Approximate target bit size of the result; only used to
            decide how many small factors to draw.
        small_factors_size: Bit-size parameter ``s`` for the small factors;
            each is requested in the range ``[2**(s-3), 2**(s+1)]``.
        big_factor_size: Bit-size parameter for the single big factor, using
            the same range convention.

    Returns:
        A tuple ``(p, factors)`` where ``factors`` lists the chosen primes
        (starting with 2) and ``p == product(factors) + 1`` passed
        ``is_prime``.
    """
    def draw(bits):
        # One random prime with lower bound 2**(bits-3) and upper bound
        # 2**(bits+1), as passed to random_prime.
        return random_prime(1 << (bits + 1), lbound=1 << (bits - 3))

    # The big prime is drawn first, followed by all but one small prime.
    big = draw(big_factor_size)
    factors = [2, big]
    product = 2 * big

    small_count = (total_size - big_factor_size) // small_factors_size
    for _ in range(small_count - 1):
        small = draw(small_factors_size)
        factors.append(small)
        product *= small

    # Only the final small factor is re-drawn until product * factor + 1 is
    # prime; this avoids regenerating the whole factor list on each attempt.
    while True:
        closing = draw(small_factors_size)
        candidate = product * closing + 1
        if is_prime(candidate):
            break

    factors.append(closing)
    return candidate, factors
def is_pkcs7_padded(message):
    """Return True if ``message`` ends with well-formed PKCS#7-style padding.

    The last byte ``n`` is read as the padding length; the message is
    considered padded when ``1 <= n <= len(message)`` and each of the final
    ``n`` bytes equals ``n``. No block-size limit is applied to ``n``.

    Args:
        message: A ``bytes``-like sequence of integers (e.g. decrypted data).

    Returns:
        ``True`` when the trailing bytes form valid padding, ``False``
        otherwise -- including for an empty message, which previously raised
        ``IndexError``.
    """
    if not message:
        # Nothing to inspect: an empty message cannot carry padding.
        return False

    pad_len = message[-1]
    # A zero or over-long pad length can never describe valid padding;
    # reject it explicitly rather than relying on slice wrap-around.
    if pad_len == 0 or pad_len > len(message):
        return False

    return all(byte == pad_len for byte in message[-pad_len:])
def decrypt_flag(shared_secret: int, iv: str, ciphertext: str):
    """Decrypt a hex-encoded AES-CBC ciphertext using a key derived from
    ``shared_secret``.

    The key is the first 16 bytes of the SHA-1 digest of the decimal string
    form of ``shared_secret``.

    Args:
        shared_secret: Integer shared secret; converted with ``str()`` before
            hashing.
        iv: Initialisation vector as a hex string.
        ciphertext: Encrypted payload as a hex string.

    Returns:
        The plaintext decoded as ASCII, with padding stripped (block size 16)
        when ``is_pkcs7_padded`` reports it as padded, otherwise decoded
        as-is.
    """
    # Key derivation: SHA-1 over the ASCII decimal digits, truncated to 16 bytes.
    digest = hashlib.sha1(str(shared_secret).encode('ascii')).digest()
    aes_key = digest[:16]

    # Decode the ciphertext before the IV, matching the original order so
    # that a malformed input surfaces the same error first.
    raw_ct = bytes.fromhex(ciphertext)
    raw_iv = bytes.fromhex(iv)

    decrypted = AES.new(aes_key, AES.MODE_CBC, raw_iv).decrypt(raw_ct)

    if not is_pkcs7_padded(decrypted):
        return decrypted.decode('ascii')
    return unpad(decrypted, 16).decode('ascii')
# --- Script body: interact with the remote service and recover the flag. ---
# The exchange below is strictly ordered: three lines are read, one JSON
# message is sent, then two more lines are read.
io = remote('socket.cryptohack.org', 13378)
# Each received line is treated as "<prefix>: <json>"; the split keeps the text
# after the first occurrence of the given separator and parses it as JSON.
# First line: parameters p, g and public value A, all hex strings.
alice = json.loads(io.recvline().strip().decode().split("e: ")[1])
p = int(alice['p'], 16)
g = int(alice['g'], 16)
A = int(alice['A'], 16)
# Second line: public value B (parsed here but not used later in this script).
bob = json.loads(io.recvline().strip().decode().split("b: ")[1])
B = int(bob['B'], 16)
# Third line: the encrypted flag with its IV (both kept as hex strings).
ct = json.loads(io.recvline().strip().decode().split("e: ")[1])
iv = ct['iv']
enc_flag = ct['encrypted']
send_bob = dict()
# Generate a replacement modulus n = product(factors) + 1 via
# B_smooth(total_size=2000, small_factors_size=15, big_factor_size=40).
# NOTE(review): presumably chosen so that discrete_log modulo n is feasible
# because n - 1 has only small prime factors -- confirm.
gen = B_smooth(2000, 15, 40)
n = gen[0]
# Debug aid (disabled): inspect the factors of n - 1.
# factors = gen[1]
# print(factors)
# Send the same generator g, but with modulus n and a fresh public value
# g^x mod n, where x is a random 1024-bit prime from getPrime.
send_bob['p'] = hex(n)
send_bob['g'] = hex(g)
send_bob['A'] = hex(pow(g, getPrime(1024), n))
io.sendline(json.dumps(send_bob).encode())
# Two reply lines follow: one carrying a new 'B' value, one carrying an
# 'iv'/'encrypted' pair.
bob_secret = json.loads(io.recvline().strip().decode().split("u: ")[1])
bob_encrypted = json.loads(io.recvline().strip().decode().split("u: ")[1])
# The reply's iv/encrypted values are extracted but not used below.
bob_encrypted_iv = bob_encrypted['iv']
bob_encrypted_sth = bob_encrypted['encrypted']
bob_secret = int(bob_secret['B'], 16)
# Move g and the returned B into the ring of integers modulo n (Zmod is not
# imported in this file; presumably provided by Sage -- TODO confirm) and solve
# g^b = B for the exponent b.
F = Zmod(n)
g = F(g)  # rebinds g from a plain integer to an element of F
bob_secret = F(bob_secret)
b = discrete_log(bob_secret, g)
# Use the recovered exponent with the originally received A and p to compute
# the shared secret, then decrypt the flag captured in the third line.
shared_secret = pow(A, b, p)
print(decrypt_flag(shared_secret, iv, enc_flag))
|