0CTF/TCTF 2020 Quals - emmm (crypto)

In this challenge, we are given a simple block cipher (though not fully invertible) built from three modular multiplications.

```python
# Sage mode
import random

P = 247359019496198933  # ~2**57.78
C = 223805275076627807  # ~2**57.64
M = 2**60
K0 = random.randint(1, P-1)
K1 = random.randint(1, P-1)

# not a bijection? can be adjusted but I'm lazy
def encrypt_block(x):
    tmp = x * K0 % P    # step 1: multiply by K0 mod P
    tmp = tmp * C % M   # step 2: multiply by C mod 2**60 (result may exceed P)
    tmp = tmp * K1 % P  # step 3: multiply by K1 mod P
    return tmp
```

We are also given $2^{24}$ random plaintext-ciphertext pairs and the encrypted flag.

Let's build a linear relation between a plaintext and its ciphertext by introducing one variable per modular reduction.

Consider an encryption $(x, y)$. Let $$ \begin{align} t_1 &= \lfloor K_0x / P \rfloor < x, \\ t_2 &= \lfloor (K_0x\bmod{P})\cdot C / M \rfloor < PC/M, \\ t_3 &= \lfloor K_1^{-1}y / P \rfloor - \epsilon < y, \\ \end{align} $$ where $0 \le \epsilon \le \lfloor M / P \rfloor = 4$ is chosen so that $K_1^{-1}y - t_3P$ equals the output of the second encryption step (tmp = tmp * C % M). This correction is needed because the value after the second step can be as large as $M=2^{60}$ and is then reduced modulo $P\approx 2^{57.78}$, so a couple of bits of information are lost.

Then, applying the first two encryption steps to $x$ and undoing the last step from $y$, we get: $$ \begin{align} & (K_0x - t_1P)C - t_2M = K_1^{-1}y - t_3P, \\ \Rightarrow & xC\cdot K_0 - y\cdot K_1^{-1} - PC\cdot t_1 - M\cdot t_2 + P\cdot t_3 = 0, \end{align} $$ with unknowns $$ \begin{align} 0 \le~ & K_0 < P,\\ 0 \le~ & K_1^{-1} < P,\\ 0 \le~ & t_1 < x,\\ 0 \le~ & t_2 < PC/M,\\ -\epsilon \le~ & t_3 < y. \end{align} $$
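To double-check the definitions and the relation, here is a small pure-Python sketch (not part of the original solution). It reuses the challenge constants but draws fresh random keys, since the real ones are unknown at this point. The helper names `random_unit` and `linearize` are made up for illustration. For random 64-bit plaintexts it computes $t_1, t_2, \epsilon, t_3$ and asserts both the linear relation and the bounds. `pow(k, -1, P)` requires Python 3.8 or newer.

```python
import math
import random

# Challenge constants; the keys below are random stand-ins, not the real ones.
P = 247359019496198933
C = 223805275076627807
M = 2**60


def random_unit():
    """Hypothetical helper: a random key in [1, P) that is invertible mod P."""
    while True:
        k = random.randrange(1, P)
        if math.gcd(k, P) == 1:
            return k


K0, K1 = random_unit(), random_unit()
K1_inv = pow(K1, -1, P)  # K_1^{-1} as an integer in [0, P)


def linearize(x):
    """Encrypt x and return (y, t1, t2, t3, eps) as defined in the text."""
    a = x * K0 % P           # step 1
    b = a * C % M            # step 2, can be as large as M > P
    y = b * K1 % P           # step 3
    t1 = K0 * x // P
    t2 = a * C // M
    eps = b // P             # multiples of P dropped by step 3
    t3 = K1_inv * y // P - eps
    assert K1_inv * y - t3 * P == b  # decryption side matches step 2 output
    return y, t1, t2, t3, eps


for _ in range(10_000):
    x = random.randrange(1, 2**64)
    y, t1, t2, t3, eps = linearize(x)
    # The linearized relation holds exactly over the integers.
    assert x * C * K0 - y * K1_inv - P * C * t1 - M * t2 + P * t3 == 0
    # Bounds on the unknowns (t3 < y holds whenever y > 0).
    assert 0 <= eps <= M // P == 4
    assert 0 <= t1 < x
    assert 0 <= t2 * M < P * C
    assert -eps <= t3 <= y
```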

Note that $x, y, t_1, t_2, t_3$ are different for each known data pair, while $K_0$ and $K_1^{-1}$ are shared by all of them.

We can now use LLL to solve this constraint system. As an example, consider the following lattice for $n=3$ data pairs (rows as vectors): $$ \begin{matrix} &~~~~ eq_0 ~~~~~~ eq_1 ~~~~~~ eq_2 ~~~ K_0 ~~~ K_1 ~~~ . ~~~~ t_{1,i} ~~~ . ~~~~ . ~~~~ t_{2,i} ~~~ . ~~~~ . ~~~~ t_{3,i} ~~~ . \hfill \\ \begin{matrix} K_0 \\ K_1 \\ . \\ t_{1,i} \\ . \\ . \\ t_{2,i} \\ . \\ . \\ t_{3,i} \\ . \end{matrix} \hspace{-1em}& \begin{pmatrix} x_0C & x_1C & x_2C & 1 & & & & & & & & & & \\ y_0 & y_1 & y_2 & & 1 & & & & & & & & & \\ CP & & & & & 1 & & & & & & & & \\ & CP & & & & & 1 & & & & & & & \\ & & CP & & & & & 1 & & & & & & \\ M & & & & & & & & 1 & & & & & \\ & M & & & & & & & & 1 & & & & \\ & & M & & & & & & & & 1 & & & \\ P & & & & & & & & & & & 1 & & \\ & P & & & & & & & & & & & 1 & \\ & & P & & & & & & & & & & & 1 \\ \end{pmatrix} \end{matrix} $$
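The same lattice can be written down for any $n$ without Sage. The following is a minimal sketch that uses plain Python lists, hypothetical helpers `build_lattice` and `combine`, and random stand-in keys. It builds the rows in the order shown above for $n = 3$ synthetic pairs. It then checks that the integer combination with coefficients $(K_0, -K_1^{-1}, -t_{1,i}, -t_{2,i}, t_{3,i})$ vanishes in the first $n$ columns and reproduces those coefficients in the identity block.

```python
import math
import random

P = 247359019496198933
C = 223805275076627807
M = 2**60


def build_lattice(pairs):
    """Rows of the unscaled lattice for pairs [(x_i, y_i)]: n equation columns + identity."""
    n = len(pairs)
    dim = 2 + 3 * n
    rows = [[0] * (n + dim) for _ in range(dim)]
    for r in range(dim):
        rows[r][n + r] = 1                  # identity block
    for i, (x, y) in enumerate(pairs):
        rows[0][i] = x * C                  # K_0 row
        rows[1][i] = y                      # K_1^{-1} row
        rows[2 + i][i] = C * P              # t_{1,i} rows
        rows[2 + n + i][i] = M              # t_{2,i} rows
        rows[2 + 2 * n + i][i] = P          # t_{3,i} rows
    return rows


def combine(rows, coeffs):
    """Integer linear combination sum(c_r * row_r)."""
    return [sum(c * row[j] for c, row in zip(coeffs, rows)) for j in range(len(rows[0]))]


# Toy instance with random stand-in keys (both invertible mod P).
while True:
    K0, K1 = random.randrange(1, P), random.randrange(1, P)
    if math.gcd(K0 * K1, P) == 1:
        break
K1_inv = pow(K1, -1, P)

n = 3
pairs, t1s, t2s, t3s = [], [], [], []
for _ in range(n):
    x = random.randrange(1, 2**64)
    a = x * K0 % P
    b = a * C % M
    y = b * K1 % P
    pairs.append((x, y))
    t1s.append(K0 * x // P)
    t2s.append(a * C // M)
    t3s.append((K1_inv * y - b) // P)   # exact division

coeffs = [K0, -K1_inv] + [-t for t in t1s] + [-t for t in t2s] + t3s
v = combine(build_lattice(pairs), coeffs)
assert v[:n] == [0] * n          # equation columns vanish
assert v[n:] == coeffs           # identity block exposes the unknowns
```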

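We are looking for an integer combination of rows that makes the first $n$ entries zero and keeps the others within our bounds. We can achieve this by scaling the coordinates (columns). The first $n$ columns should get a very large weight, forcing LLL to make them zero. Every other column should be divided by its bound, so that each coordinate of the target vector becomes smaller than 1 in absolute value. In the code below the first $n$ columns simply keep weight 1: once the other columns are divided by their large bounds, this is already a very large relative weight. After running LLL, we scale back.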
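Sage's `m.LLL()` does the heavy lifting here. For readers without Sage, the following is a textbook LLL sketch in pure Python with exact `fractions.Fraction` arithmetic and $\delta = 3/4$. It only illustrates the reduction step and is not meant to mirror Sage's implementation. It assumes linearly independent rows and recomputes the Gram-Schmidt data after every update, so it is only practical for tiny lattices, not the $62 \times 82$ one used here.

```python
from fractions import Fraction


def dot(u, v):
    """Exact inner product of two rational vectors."""
    return sum(a * b for a, b in zip(u, v))


def gram_schmidt(b):
    """Gram-Schmidt vectors b* and coefficients mu of the rows b (assumed independent)."""
    n = len(b)
    bstar = []
    mu = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        v = list(b[i])
        for j in range(i):
            mu[i][j] = dot(b[i], bstar[j]) / dot(bstar[j], bstar[j])
            v = [a - mu[i][j] * c for a, c in zip(v, bstar[j])]
        bstar.append(v)
    return bstar, mu


def lll(basis, delta=Fraction(3, 4)):
    """Textbook LLL reduction of the rows of `basis` (exact and slow: small inputs only)."""
    b = [[Fraction(x) for x in row] for row in basis]
    n = len(b)
    bstar, mu = gram_schmidt(b)
    k = 1
    while k < n:
        # Size reduction: make |mu[k][j]| <= 1/2 for all j < k.
        for j in range(k - 1, -1, -1):
            q = round(mu[k][j])
            if q:
                b[k] = [a - q * c for a, c in zip(b[k], b[j])]
                bstar, mu = gram_schmidt(b)
        # Lovasz condition: advance, or swap the two rows and step back.
        if dot(bstar[k], bstar[k]) >= (delta - mu[k][k - 1] ** 2) * dot(bstar[k - 1], bstar[k - 1]):
            k += 1
        else:
            b[k], b[k - 1] = b[k - 1], b[k]
            bstar, mu = gram_schmidt(b)
            k = max(k - 1, 1)
    return b
```

As a usage example, `lll([[1, 1, 1], [-1, 0, 2], [3, 5, 6]])` returns an integer basis of the same lattice, with $|\det| = 3$, that is size-reduced and satisfies the Lovász condition.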

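```python
# Read the known (plaintext, ciphertext) pairs; the loop stops at the first
# line that is not two integers (the hex-encoded encrypted flag).
data = []
with open("res") as f:
    for line in f:
        try:
            pt, ct = map(int, line.split())
        except ValueError:
            break
        data.append((pt, ct))
```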

Note that the bounds of $t_{1,i}$ and $t_{3,i}$ depend on the actual plaintexts and ciphertexts ($t_{1,i} < x_i$, $t_{3,i} < y_i$). We are thus interested in the smallest plaintexts and ciphertexts. As we shall see, the 20 smallest pairs are enough!

```python
# Smallest pairs first: they give the tightest bounds on t_{1,i} and t_{3,i}
data.sort(key=lambda a: a[0]**2 + a[1]**2)

n = 20
pairs = data[:n]
```
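Sorting all $2^{24}$ pairs does more work than needed when only the $n$ smallest are used. As a minimal alternative sketch, the standard library's `heapq.nsmallest` can be used; the Python documentation describes it as equivalent to `sorted(iterable, key=key)[:n]`. The synthetic `data` below is a random stand-in for the parsed file, and `norm2` is a hypothetical helper name.

```python
import heapq
import random

# Random stand-in for the parsed (plaintext, ciphertext) pairs.
data = [(random.getrandbits(64), random.getrandbits(58)) for _ in range(100_000)]


def norm2(pair):
    """Squared Euclidean norm x**2 + y**2 used to rank the pairs."""
    x, y = pair
    return x * x + y * y


n = 20
pairs = heapq.nsmallest(n, data, key=norm2)  # O(N log n) instead of a full sort
```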
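```python
# Lattice basis: n equation columns followed by an identity block
m = matrix(QQ, 2 + 3*n, 2 + 4*n)
m[:,n:] = identity_matrix(2+3*n)
for i, (x, y) in enumerate(pairs):
    m[0,i] = C*x          # K_0 row
    m[1,i] = y            # K_1^{-1} row
    m[0*n+2+i,i] = C*P    # t_{1,i} rows
    m[1*n+2+i,i] = M      # t_{2,i} rows
    m[2*n+2+i,i] = P      # t_{3,i} rows

# Column bounds (the equation columns keep weight 1)
bounds  = [1] * n + [P] * 2
bounds += [pt for pt, ct in pairs]   # t_{1,i} < x_i
bounds += [P*C / M] * n              # t_{2,i} < PC/M
bounds += [ct for pt, ct in pairs]   # t_{3,i} < y_i
assert len(bounds) == m.ncols()

# scale each column by 1/bound
for i, b in enumerate(bounds):
    m.set_column(i, m.column(i)/b)
# LLL
m = m.LLL()
# scale back
for i, b in enumerate(bounds):
    m.set_column(i, m.column(i)*b)
```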
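```python
# Columns n and n+1 of a reduced row are candidates for K_0 and -K_1^{-1};
# accept the row if it re-encrypts every chosen pair correctly.
for irow, row in enumerate(m):
    k0, negk1i = row[n:n+2]
    if gcd(negk1i, P) == 1:
        k1 = inverse_mod(-int(negk1i), P)
        for x, y in pairs:
            tmp = x * k0 % P
            tmp = tmp * C % M
            tmp = tmp * k1 % P
            if tmp != y:
                break
        else:
            # all pairs matched: this row holds the key
            k0 %= P
            k1 %= P
            print("Row %d: key recovered: %x %x" % (irow, k0, k1))
            break
```

```
Row 11: key recovered: 1df19a439748567 29ad0f3aac513b9
```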

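That was fast! Now let's decrypt the flag. Recall that a couple of bits are lost in the last encryption step, so we have to try every $\epsilon \in \{0, \dots, 4\}$ for each block. The first reduction modulo $P$ loses a few bits as well. Plaintext blocks are 64-bit while $P \approx 2^{57.78}$, so if $v < P$ is the value obtained by undoing the first step, every $v + kP < 2^{64}$ is also a preimage.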

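```python
k0i = inverse_mod(k0, P)
k1i = inverse_mod(k1, P)
Ci = inverse_mod(C, M)

def decrypt_block(y):
    for t in range(5):            # t plays the role of epsilon
        v = y * k1i % P
        v += t*P                  # restore the multiple of P lost in step 3
        if v >= M: continue
        v = v * Ci % M            # undo step 2
        if v >= P: continue       # step 1 always outputs a value < P
        v = v * k0i % P           # undo step 1
        while v < 2**64:          # every v + k*P < 2**64 is a preimage
            #assert encrypt_block(v) == y
            yield v
            v += P
```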
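The candidate enumeration in `decrypt_block` can be checked in isolation. Below is a minimal pure-Python restatement. The names `encrypt_block` and `decrypt_candidates` and the random stand-in keys are illustrative assumptions, not the recovered key. The sketch guesses $\epsilon \in \{0, \dots, 4\}$, undoes the three steps, and lists every 64-bit preimage. Because $C$ is odd, it is invertible modulo $2^{60}$.

```python
import math
import random

P = 247359019496198933
C = 223805275076627807
M = 2**60

# Random stand-in keys (invertible mod P), not the recovered ones.
while True:
    K0, K1 = random.randrange(1, P), random.randrange(1, P)
    if math.gcd(K0 * K1, P) == 1:
        break
K0_inv, K1_inv, C_inv = pow(K0, -1, P), pow(K1, -1, P), pow(C, -1, M)


def encrypt_block(x):
    tmp = x * K0 % P
    tmp = tmp * C % M
    return tmp * K1 % P


def decrypt_candidates(y):
    """All 64-bit x with encrypt_block(x) == y."""
    out = []
    for eps in range(M // P + 1):        # guess the multiple of P lost in step 3
        b = y * K1_inv % P + eps * P
        if b >= M:
            continue
        a = b * C_inv % M                # undo step 2
        if a >= P:
            continue                     # step 1 always outputs a < P
        x = a * K0_inv % P               # undo step 1
        while x < 2**64:                 # step 1 also forgot multiples of P
            out.append(x)
            x += P
    return out
```

Every candidate re-encrypts to the same block. The printable-ASCII filter in the flag-decoding step is therefore what singles out the actual plaintext.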
```python
from struct import unpack, pack

# The encrypted flag is the last (hex) token of the file.
ct = open("res").read()[-200:].strip().split()[-1]
ct = bytes.fromhex(ct)
fmt = '<%dQ' % (len(ct) // 8)  # little-endian 64-bit blocks, matching pack("<Q", ...) below
ct = unpack(fmt, ct)

flag = b""
for block in ct:
    for dec in decrypt_block(block):
        dec = pack("<Q", dec)
        # keep only candidates that are fully printable ASCII
        if all(0x20 <= v < 127 for v in dec):
            print(dec)
            flag += dec

print(b"flag{" + flag + b"}")
```

```
b'_p4droNe'
b'_a5k3d_m'
b'3_7o_br1'
b'ng_tHis_'
b'foR_thE_'
b'5ign0Ra.'
b'flag{_p4droNe_a5k3d_m3_7o_br1ng_tHis_foR_thE_5ign0Ra.}'
```