In [ ]:
# This is the problem we need to solve: the challenge server's source, reproduced for reference.
import random
import select
import signal
import sympy
import sys


class Unbuffered(object):
    """File-like wrapper that flushes the wrapped stream after every write.

    Attributes not defined on this class (e.g. ``fileno``, ``getvalue``) are
    looked up on the wrapped stream.
    """

    def __init__(self, stream):
        # The wrapped file-like object (sys.stdout in main()).
        self.stream = stream

    def write(self, data):
        """Write ``data`` to the wrapped stream, then flush it immediately."""
        target = self.stream
        target.write(data)
        target.flush()

    def __getattr__(self, attr):
        """Delegate any other attribute lookup to the wrapped stream."""
        wrapped = self.stream
        return getattr(wrapped, attr)


def random_prime(bits):
    """Return ``sympy.nextprime`` of a random integer in [2**bits, 2**(bits + 1)].

    The offset added to 2**bits is drawn with ``random.randint(0, 2**bits)``.
    """
    base = 2 ** bits
    offset = random.randint(0, base)
    return sympy.nextprime(base + offset)


def encrypt(bits, m):
    """Print a fresh modulus n, then m**3 % n and (m + 1)**3 % n, one per line.

    n is the product of two primes from ``random_prime(bits)``. ``m`` must be
    smaller than n (checked with an assert).
    """
    modulus = random_prime(bits) * random_prime(bits)
    assert m < modulus
    outputs = (modulus, m ** 3 % modulus, (m + 1) ** 3 % modulus)
    for value in outputs:
        print(value)


def main():
    """Run the challenge session: nine timed rounds, then the encrypted flag.

    The whole session is limited by ``signal.alarm(180)`` and stdout is made
    unbuffered. Round i (1..9) uses ``bits = 50 * i``, encrypts a random m in
    [0, 4**bits] with ``encrypt`` and waits up to 10 s for the client to send m
    back. A late, non-numeric or wrong answer prints a message and exits. After
    all rounds, the contents of the local ``flag`` file are converted to an
    integer and encrypted the same way.
    """
    signal.alarm(180)
    sys.stdout = Unbuffered(sys.stdout)
    for i in range(1, 10):
        bits = 50 * i
        m = random.randint(0, 4 ** bits)
        encrypt(bits, m)
        rfd, _, _ = select.select([sys.stdin], [], [], 10)
        if not rfd:
            print("\033[31;1mToo slooooooooooow :(\033[0m")
            sys.exit()
        try:
            x = int(raw_input())
        except ValueError:
            print("\033[31;1mEnter a number, ok?\033[0m")
            sys.exit()
        if x != m:
            print("\033[31;1mso sad :(\033[0m")
            sys.exit()
        print("\033[32;1mGreat:)\033[0m")

    # NOTE(review): non-integer bit size kept exactly as in the challenge
    # source; random_prime therefore works with float powers of two here.
    bits = 512.512
    # Python 2 only: the 'hex' codec turns the flag bytes into a hex string.
    with open('flag') as flag_file:
        m = int(flag_file.read().encode('hex'), 16)
    encrypt(bits, m)
    print("\033[32;1mGood Luck!\033[0m")

# Entry point of the challenge server, reproduced for reference only.
# NOTE(review): running this cell inside the notebook kernel calls
# signal.alarm(180), replaces sys.stdout, blocks on stdin and needs a local
# 'flag' file. SIGALRM's default action presumably kills the kernel after 180 s,
# so don't execute this cell from the notebook.
main()
In [ ]:
from gmpy2 import mpz, divm
from sympy import *
In [ ]:
# Model one round: m**3 = c and (m + 1)**3 = d (mod n).
# x stands for the unknown m; c and d are the two ciphertexts as symbols.
x, c, d = symbols('x c d')
# m is a common root of both polynomials, so it is a root of their gcd.
f1 = x ** 3 - c
f2 = (x + 1) ** 3 - d
In [ ]:
# Euclid's algorithm on f1 and f2 (polynomials in x with symbolic c, d).
# First step: r = f1 - f2 = -3x^2 - 3x + d - c - 1.
q, r = div(f1, f2)
print(r)
# Second step: dividing f2 by r leaves a remainder that is linear in x.
q2, r2 = div(f2, r)
# Scale by 3 to clear the 1/3 denominators: 3*r2 = (d - c + 2)*x - (2c + d - 1),
# whose single root is m.
r2 = Poly(3 * r2, x)
In [ ]:
# Given: one round's public values (n, c = m^3 mod n, d = (m+1)^3 mod n).
# NOTE: this rebinds c and d (previously sympy symbols) to integers; re-run the
# symbol-definition cell before re-running the symbolic derivation.
n = 2346958776283104186640235260573
c = 1085431292732484513462488498356
d = 2179163442056091844263758576965
# m is the root of the linear polynomial r2, whose coefficients in x are
# [d - c + 2, -(2c + d - 1)], so
#   m = -r2.all_coeffs()[1] / r2.all_coeffs()[0] = (2c + d - 1) / (d - c + 2)  (mod n).
# Check: d + 2c - 1 = 3m(m^2 + m + 1) and d - c + 2 = 3(m^2 + m + 1).
m = divm(2 * c + d - 1, d - c + 2, n)
# Sanity check (cell output): both ciphertexts re-encrypt correctly.
pow(m, 3, n) == c and pow(m + 1, 3, n) == d
In [ ]:
# Now do it automatically against the live service: answer the nine practice
# rounds, then capture the flag round's (n, c, d) and recover m.
import telnetlib

HOST = "54.64.40.172"
PORT = 5454
N_ROUNDS = 9  # the server's main() plays rounds i = 1..9 before the flag round


def _read_challenge(conn):
    """Read one round's three lines (n, c, d) from ``conn`` as mpz values."""
    return tuple(mpz(conn.read_until('\n').strip()) for _ in range(3))


def _recover_m(n, c, d):
    """Return m mod n from c = m**3 mod n and d = (m + 1)**3 mod n.

    d + 2c - 1 = 3m(m^2 + m + 1) and d - c + 2 = 3(m^2 + m + 1), so their
    quotient modulo n is m.
    """
    return divm(2 * c + d - 1, d - c + 2, n)


tel = telnetlib.Telnet(HOST, PORT)
try:
    for i in range(1, N_ROUNDS + 1):
        n, c, d = _read_challenge(tel)
        m = _recover_m(n, c, d)
        tel.write('%d\n' % m)
        ans = tel.read_until('\n')
        print('%d %s' % (i, ans.rstrip()))
        # On a wrong/late answer the server prints an error and exits; stop here
        # rather than failing later with a confusing parse or EOF error.
        if 'Great' not in ans:
            raise RuntimeError('round %d rejected: %r' % (i, ans))
    # Flag round: the server sends (n, c, d) then "Good Luck!" and expects no answer.
    n, c, d = _read_challenge(tel)
    m = _recover_m(n, c, d)
    ans = tel.read_until('\n')
    print('%s %s %s %s' % (ans.rstrip(), n, c, d))
    print(m)
finally:
    tel.close()
In [ ]:
# Get our flag: turn the recovered plaintext integer m back into bytes.
import binascii

flag_hex = '%x' % m
# '%x' drops leading zeros; unhexlify needs an even number of hex digits.
if len(flag_hex) % 2:
    flag_hex = '0' + flag_hex
print(binascii.unhexlify(flag_hex))