HITCON CTF 2019 Quals - Very Simple Haskell (crypto)

The challenge is based on the following clean Haskell code:

import Data.Char
import System.IO

n :: Integer
n = 134896036104102133446208954973118530800743044711419303630456535295204304771800100892609593430702833309387082353959992161865438523195671760946142657809228938824313865760630832980160727407084204864544706387890655083179518455155520501821681606874346463698215916627632418223019328444607858743434475109717014763667

k :: Int
k = 131

primes :: [Integer]
primes = take k $ sieve (2 : [3, 5..])
  where
    sieve (p:xs) = p : sieve [x|x <- xs, x `mod` p > 0]

stringToInteger :: String -> Integer
stringToInteger str = foldl (\x y -> (toInteger $ ord y) + x*256) 0 str

integerToString :: Integer -> String
integerToString num = f num ""
    where
        f 0 str = str
        f num str = f (div num 256) $ (:) (chr $ fromIntegral $ num `mod` 256) str

numToBits :: Integer -> [Int]
numToBits num = f num []
    where 
        f 0 arr = arr
        f x arr = f (div x 2) ((fromInteger $ x `mod` 2) : arr)

extendBits :: Int -> [Int] -> [Int]
extendBits blockLen arr
    | len == 0 = arr
    | len > 0 = (replicate (blockLen-len) 0) ++ arr
    where len = (length arr) `mod` blockLen

calc :: Integer -> [Int] -> Integer
calc num [] = num
calc num arr = calc result restArr
    where
        num2 = num*num `mod` n
        (block, restArr) = splitAt k arr
        zipped = zipWith (\x y -> ((fromIntegral x)*y) `mod` n) block primes  
        mul = product $ filter (/=0) zipped
        result = num2*mul `mod` n

magic :: String -> String
magic input = result
    where 
        num = stringToInteger input
        bits = numToBits num
        extended = reverse $ extendBits 8 bits
        oriLen = length extended
        extendedBits = extendBits k extended
        oriLenBits = numToBits $ fromIntegral oriLen
        extendedOriLenBits = extendBits k oriLenBits
        finalBits = extendedOriLenBits ++ extendedBits
        result = show $ calc 1 (reverse finalBits)

main = do
    flag <- readFile "flag"
    putStrLn.show $ length flag
    putStrLn $ magic ("the flag is hitcon{" ++ flag ++ "}")

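The challenge is rather straightforward. The whole message ("the flag is hitcon{", the flag, then "}") is converted to an integer and then to a list of bits, most significant first. This list is left-padded with zeros to a multiple of 8 bits and reversed, then left-padded again to a multiple of 131 bits. The original bit length, itself padded to 131 bits, is prepended, and the whole list is reversed once more before being passed to $calc$.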
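The layout can be checked by porting these transformations. Below is a minimal Python sketch that mirrors numToBits, extendBits and the body of magic. The helper names `to_bits`, `extend_bits` and `magic_bits` are hypothetical, and the sketch assumes every character of the message is below 256, so `stringToInteger` matches big-endian Latin-1 bytes. It shows that the list handed to $calc$ is the message bits (MSB first), then zero padding, then the reversed length block.

```python
def to_bits(num):
    """Mirror of numToBits: binary digits, most significant first ([] for 0)."""
    out = []
    while num:
        out.append(num % 2)
        num //= 2
    return out[::-1]


def extend_bits(block_len, arr):
    """Mirror of extendBits: left-pad with zeros to a multiple of block_len."""
    r = len(arr) % block_len
    return arr if r == 0 else [0] * (block_len - r) + arr


def magic_bits(s, k=131):
    """Return the bit list that magic passes to calc, i.e. reverse finalBits."""
    num = int.from_bytes(s.encode("latin-1"), "big")  # stringToInteger
    extended = extend_bits(8, to_bits(num))[::-1]
    ori_len = len(extended)
    extended_bits = extend_bits(k, extended)
    len_bits = extend_bits(k, to_bits(ori_len))
    final_bits = len_bits + extended_bits
    return final_bits[::-1]


bits = magic_bits("the flag is hitcon{abcdef}")
print(len(bits))  # 393 = 3 blocks of 131 bits
```

For the 26-character message, the first 208 bits are the message itself, bits 208 to 261 are zeros, and the last 131 bits are the reversed length block.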

The final step is the $calc(num, arr)$ function. It splits the bit list into 131-bit blocks and, for each block, computes

$$ mul = p_1^{b_1} p_2^{b_2} \cdots p_{131}^{b_{131}} \mod{n}, $$

where $p_i$ is the $i$-th prime and $b_i$ is the $i$-th bit of the current block. It then updates $num$ to $num^2 \times mul \mod{n}$ and proceeds with the next block.
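Unrolling the update rule makes the exponents explicit. Here the input to $calc$ has $3 \cdot 131 = 393$ bits. Let $mul_1, mul_2, mul_3$ be the values of $mul$ for the three blocks, with all congruences modulo $n$:

$$
\begin{aligned}
num_1 &= 1^2 \cdot mul_1 = mul_1, \\
num_2 &= num_1^2 \cdot mul_2 = mul_1^2 \, mul_2, \\
num_3 &= num_2^2 \cdot mul_3 = mul_1^4 \, mul_2^2 \, mul_3 .
\end{aligned}
$$

Let $b_j$ now index the bits of the whole input. Substituting $mul_j = \prod_{i=1}^{131} p_i^{b_{i+131(j-1)}}$ gives $num_3 = \prod_{i=1}^{131} p_i^{4b_i + 2b_{i+131} + b_{i+262}} \mod{n}$.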

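As a result, the final output is a product of small primes raised to powers determined by the input bits, reduced modulo $n$. Here the input to $calc$ has 393 bits (three blocks). If $b_j$ denotes the $j$-th bit of the whole input, the prime $p_i$ has exponent $4b_i + 2b_{i+131} + b_{i+262}$. Several parts of the input are known: the prefix "the flag is hitcon{", the suffix "}", and the flag length (printed by the program; 6 bytes here). The only unknown part of the $calc$ input is therefore the 48 bits of the six flag bytes, and they lie entirely in the second block.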
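A quick way to gain confidence in the exponent formula is to port $calc$ and compare it with the closed form on random inputs. The sketch below is plain Python, not the author's code. `first_primes`, `calc` and `closed_form` are hypothetical helper names, and the test modulus is arbitrary; the identity holds for any modulus.

```python
import random


def first_primes(count):
    """Return the first `count` primes by trial division (enough for k = 131)."""
    found = []
    candidate = 2
    while len(found) < count:
        if all(candidate % p for p in found if p * p <= candidate):
            found.append(candidate)
        candidate += 1
    return found


K = 131
PRIMES = first_primes(K)


def calc(num, arr, n):
    """Port of the Haskell calc: for each 131-bit block, num <- num^2 * mul mod n."""
    for i in range(0, len(arr), K):
        mul = 1
        for bit, p in zip(arr[i:i + K], PRIMES):
            if bit:
                mul *= p
        num = num * num * mul % n
    return num


def closed_form(arr, n):
    """Product of p_i^(4 b_i + 2 b_{i+131} + b_{i+262}) mod n for a 393-bit input."""
    result = 1
    for i, p in enumerate(PRIMES):
        e = 4 * arr[i] + 2 * arr[i + K] + arr[i + 2 * K]
        result = result * pow(p, e, n) % n
    return result


rng = random.Random(0)
bits = [rng.randint(0, 1) for _ in range(3 * K)]
toy_n = (2**127 - 1) * (2**89 - 1)  # arbitrary test modulus, not the challenge n
print(calc(1, bits, toy_n) == closed_form(bits, toy_n))  # True
```

With the layout described above, the six flag bytes occupy bits 152 to 199 of the input, which are offsets 21 to 68 of the second block. Their primes are therefore $p_{22} = 79$ through $p_{69} = 347$, each entering with exponent 2 or 0.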

from sage.all import *

def s2b(s):
    # string -> bit string, 8 bits per character, most significant bit first
    return "".join(bin(ord(c))[2:].zfill(8) for c in s)

def b2s(b):
    # bit string -> string, 8 bits per character
    return "".join(chr(int(b[pos:pos + 8], 2)) for pos in range(0, len(b), 8))

n = 134896036104102133446208954973118530800743044711419303630456535295204304771800100892609593430702833309387082353959992161865438523195671760946142657809228938824313865760630832980160727407084204864544706387890655083179518455155520501821681606874346463698215916627632418223019328444607858743434475109717014763667
s = "the flag is hitcon{abcdef}"  # placeholder flag of the known length 6
# replicate the bit transformations of magic, marking the unknown flag bits as None
extended = list(map(int, s2b(s)))
extended[-7*8:-1*8] = [None] * 48  # the six flag bytes are unknown
extended = extended[::-1]
extendedBits = [0] * (262 - len(extended)) + extended
oriLen = 26 * 8  # the message has 26 characters
oriLenBits = Integer(oriLen).bits()[::-1]  # Sage's bits() is least significant first
extendedOriLenBits = [0] * (131 - len(oriLenBits)) + oriLenBits
out = extendedOriLenBits + extendedBits

# run calc on the known bits only, collecting the primes of the unknown bits
s = out[::-1]
v = 1
prs = []
for i in range(0, len(s), 131):
    mul = 1
    for bit, pr in zip(s[i:i+131], primes(1000)):
        if bit is None:
            prs.append(pr)
        elif bit == 1:
            mul *= pr
    v = v * v * mul % n

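Multiplying the challenge output by $v^{-1} \mod{n}$ cancels every known contribution and leaves only the product of the primes belonging to unknown bits equal to 1. Because these bits sit in the second block, each such prime appears squared. Taking square roots modulo a composite $n$ whose factorization is unknown is as hard as factoring $n$.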

chal = 84329776255618646348016649734028295037597157542985867506958273359305624184282146866144159754298613694885173220275408231387000884549683819822991588176788392625802461171856762214917805903544785532328453620624644896107723229373581460638987146506975123149045044762903664396325969329482406959546962473688947985096
diff = int(chal * inverse_mod(v, n)) % n  # cancel the known part: chal / v mod n
print("flagprod", diff)
print("factor", factor(diff))
flagprod 3406218222930966554172275269328526576001581947668541896752967656582693660956578801
factor 83^2 * 89^2 * 97^2 * 103^2 * 107^2 * 127^2 * 173^2 * 197^2 * 223^2 * 227^2 * 229^2 * 233^2 * 239^2 * 257^2 * 283^2 * 311^2 * 337^2 * 347^2
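We can bound in advance whether reduction modulo $n$ could ever interfere. The 48 unknown bits correspond to the primes from 79 to 347 (offsets 21 to 68 of the second block, following the layout above). Even if every flag bit were 1, the product of their squares is at most $347^{96} \approx 2^{810}$. That is far below the roughly 1024-bit $n$. A plain-Python sketch of the check, where `primes_between` is a hypothetical helper:

```python
def primes_between(lo, hi):
    """Primes p with lo <= p <= hi, by trial division."""
    return [p for p in range(max(lo, 2), hi + 1)
            if all(p % d for d in range(2, int(p ** 0.5) + 1))]


n = 134896036104102133446208954973118530800743044711419303630456535295204304771800100892609593430702833309387082353959992161865438523195671760946142657809228938824313865760630832980160727407084204864544706387890655083179518455155520501821681606874346463698215916627632418223019328444607858743434475109717014763667

candidates = primes_between(79, 347)  # one prime per unknown flag bit
worst_case = 1
for p in candidates:
    worst_case *= p * p  # every unknown bit set to 1

print(len(candidates))  # 48
print(worst_case < n)   # True: the modular reduction never triggers
```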

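Luckily, we don't have to compute any square roots. The product of the squared primes is small enough that it is never reduced modulo $n$, so the value above is the exact integer product. A flag bit is 1 exactly when the square of its prime divides this product, so recovering the flag is easy: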

bs = []
for p in prs:
    # the bit is 1 exactly when p^2 divides the recovered product
    if diff % p**2:
        bs.append(0)
    else:
        bs.append(1)
bs = "".join(map(str, bs))
print("hitcon{%s}" % b2s(bs))
hitcon{v@!>A#}
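As a cross-check that does not need Sage, the flag can be read directly off the factorization printed above. The sketch assumes, as derived from the layout, that the unknown bits map in order (most significant bit first) to the 48 primes from 79 to 347. A bit is 1 exactly when its prime occurs squared. `primes_between` and `decode` are hypothetical helper names.

```python
def primes_between(lo, hi):
    """Primes p with lo <= p <= hi, by trial division."""
    return [p for p in range(max(lo, 2), hi + 1)
            if all(p % d for d in range(2, int(p ** 0.5) + 1))]


# primes appearing as p^2 in the factor(diff) output above
SQUARED = {83, 89, 97, 103, 107, 127, 173, 197, 223, 227,
           229, 233, 239, 257, 283, 311, 337, 347}


def decode(squared, unknown_primes):
    """Bit j of the flag is 1 iff the j-th unknown prime appears squared."""
    bits = "".join("1" if p in squared else "0" for p in unknown_primes)
    return "".join(chr(int(bits[i:i + 8], 2)) for i in range(0, len(bits), 8))


flag = decode(SQUARED, primes_between(79, 347))
print("hitcon{%s}" % flag)  # prints hitcon{v@!>A#}
```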