Invalid curve attack on elliptic curve.

Information

  • category : crypto
  • points : 338

Description

Exchange your keys

nc 134.175.225.42 8848

1 file: task.py

Writeup

I solved this challenge together with The_Lillo and 0ssigeno.

We are given a task.py which is the script running on the server:

import os,random,sys,string
from hashlib import sha256
import SocketServer
import signal
from FLAG import flag
from gmpy2 import invert
from Crypto.Util.number import bytes_to_long, long_to_bytes

# Curve y^2 = x^3 + a*x + b over GF(q). `b` is only sent to the client in
# handle(); add() and mul() never reference it.
# `zero` = (0,0) is the sentinel that add()/mul() use for the point at infinity.
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
zero = (0,0)

def add(p1,p2):
	"""Add two affine points with the chord-and-tangent formulas mod q.

	`zero` = (0,0) stands for the point at infinity. Only `a` and `q` enter
	the formulas and `b` is never used. So any pair (x, y) is effectively
	treated as a point of y^2 = x^3 + a*x + b' with b' = y^2 - x^3 - a*x,
	which is what makes the invalid-curve attack possible.
	"""
	if p1 == zero:
		return p2
	if p2 == zero:
		return p1
	(p1x,p1y),(p2x,p2y) = p1,p2
	# Same x with different y (for points of one curve: P + (-P)), or
	# doubling a point with y == 0: the result is the point at infinity.
	if p1x == p2x and (p1y != p2y or p1y == 0):
	    return zero
	if p1x == p2x:
	    # Doubling: tangent slope (3*x^2 + a) / (2*y).
	    tmp = (3 * p1x * p1x + a) * invert(2 * p1y , q) % q
	else:
	    # Addition: chord slope (y2 - y1) / (x2 - x1).
	    tmp = (p2y - p1y) * invert(p2x - p1x , q) % q
	x = (tmp * tmp - p1x - p2x) % q
	y = (tmp * (p1x - x) - p1y) % q
	return (int(x),int(y))


def mul(n,p):
	"""Return n*p by binary double-and-add, consuming n from its lowest bit.

	Returns `zero` (the point at infinity) when n <= 0.
	"""
	r = zero
	tmp = p
	while 0 < n:
	    # Add the current power-of-two multiple when this bit of n is set.
	    if n & 1 == 1:
	        r = add(r,tmp)
	    n, tmp = n >> 1, add(tmp,tmp)
	return r

def pointToString(p):
	# Format a point as "(x,y)" with decimal coordinates.
	return "(" + str(p[0]) + "," + str(p[1]) + ")"

# Public base point P; handle() sends it along with Q = secret*P.
Px = 0xb55c08d92cd878a3ad444a3627a52764f5a402f4a86ef700271cb17edfa739ca
Py = 0x49ee01169c130f25853b66b1b97437fb28cfc8ba38b9f497c78f4a09c17a7ab2
P = (Px,Py)

class Task(SocketServer.BaseRequestHandler):
	"""Per-connection request handler of the challenge's ECDH service.

	After the proof of work it draws a random secret, sends the curve
	parameters together with P and Q = secret*P, runs one mandatory Exchange
	and then serves up to 90 menu commands (Exchange / Encrypt / Backdoor).
	"""
	def proof_of_work(self):
		"""Ask for the 4 missing characters of a random 20-character string.

		The client sees the last 16 characters and the SHA-256 digest of the
		whole string. Returns True only if the stripped reply is 4 characters
		long and completes the string to that digest.
		"""
		# This seeding also determines the later random.randint() call that
		# draws self.secret in handle().
		random.seed(os.urandom(8))
		proof = "".join([random.choice(string.ascii_letters+string.digits) for _ in range(20)])
		digest = sha256(proof).hexdigest()
		self.request.send("sha256(XXXX+%s) == %s\n" % (proof[4:],digest))
		self.request.send("Give me XXXX:")
		x = self.request.recv(10)
		x = x.strip()
		if len(x) != 4 or sha256(x+proof[4:]).hexdigest() != digest: 
			return False
		return True

	def recvall(self, sz):
		"""Read until the data ends with a newline or `sz` bytes have arrived.

		Returns the data with surrounding whitespace stripped, or "" if any
		exception is raised while reading.
		"""
		try:
			r = sz
			res = ""
			# NOTE(review): if the peer closes the socket, recv() keeps
			# returning "" and this loop only ends when signal.alarm fires.
			while r > 0:
				res += self.request.recv(r)
				if res.endswith("\n"):
					r = 0
				else:
					r = sz - len(res)
			res = res.strip()
		except:
			res = ""
		return res.strip("\n")

	def dosend(self, msg):
		# Best-effort send: any error is deliberately ignored.
		try:
			self.request.sendall(msg)
		except:
			pass
	

	def handle(self):
		"""Run one session: PoW, key setup, then the 90-round command menu."""
		try:
			if not self.proof_of_work():
				return
			# The whole session is limited to 300 seconds (SIGALRM).
			signal.alarm(300)
			self.secret = random.randint(0,q)
			Q = mul(self.secret,P)

			self.dosend("Welcome to the ECDH System.\n")   
			self.dosend("The params are: \n")
			self.dosend("q: " + str(q) + "\n")
			self.dosend("a: " + str(a) + "\n")
			self.dosend("b: " + str(b) + "\n")
			self.dosend("P: " + pointToString(P) + "\n")
			self.dosend("Q: " + pointToString(Q) + "\n")
			# One Exchange is mandatory so that self.key exists before Encrypt.
			self.exchange()
			for _ in range(90):
				self.dosend("Tell me your choice:\n")
				# 9 bytes fit "Exchange\n" / "Backdoor\n", the longest commands.
				choice = self.recvall(9)
				if choice == "Exchange":
					self.exchange()
				elif choice == "Encrypt":
					self.encrypt()
				elif choice == "Backdoor":
					self.backdoor()
				else:
					self.dosend("No such choice!\n")
			self.dosend("Bye bye~\n")	 
			self.request.close()
		# Bare except: this also catches the SystemExit raised by backdoor()'s
		# exit(0), so every Backdoor attempt ends here and closes the session.
		except:
			self.dosend("Something error!\n")
			self.request.close()

	def pad(self,m):
		# Left-pad the bit list in place with zeros to 2*q.bit_length() bits
		# (512 for this q); lists already that long are returned unchanged.
		pad_length = q.bit_length()*2 - len(m)
		for _ in range(pad_length):
			m.insert(0,0)
		return m

	def encrypt(self):
		"""XOR oracle: XOR the message bits with the current key bits."""
		self.dosend("Give me your message(hex):\n") 
		msg = self.recvall(150)
		# Bit expansion of the message; its leading zero bits are dropped.
		data = [int(i) for i in list('{0:0b}'.format(bytes_to_long(msg.decode("hex"))))] 
		# The key bits repeat if the message is longer than the key.
		enc = [data[i] ^ self.key[i%len(self.key)] for i in range(len(data))]
		result = 0
		for bit in enc:
			result = (result << 1) | bit
		# long_to_bytes drops leading zero bytes, so the reply can be shorter
		# than the message.
		result =  long_to_bytes(result).encode("hex")
		self.dosend("The result is:\n") 
		self.dosend(result + "\n") 

	def pointToKeys(self,p):
		# Key bits: x in the high part, y in the low q.bit_length() bits,
		# left-padded to 2*q.bit_length() bits by pad().
		x = p[0]
		y = p[1]
		tmp = x << q.bit_length() | y
		res = self.pad([int(i) for i in list('{0:0b}'.format(tmp))]) 
		return res

	def exchange(self):
		"""Receive a point G from the client and set self.key from secret*G.

		NOTE: (x, y) is never checked against y^2 = x^3 + a*x + b, so a point
		of a different curve (same `a`, another `b`) is accepted. This is the
		invalid-curve vulnerability.
		"""
		self.dosend("Give me your key:\n") 
		self.dosend("X:\n") 
		x = int(self.recvall(80))
		self.dosend("Y:\n") 
		y = int(self.recvall(80))
		key = (x,y)
		result = mul(self.secret,key)
		self.key = self.pointToKeys(result)
		self.dosend("Exchange success\n") 

	def backdoor(self):
		# Send the flag if the guessed secret is exact. exit(0) runs on both
		# branches, so there is only one guess per session.
		self.dosend("Give me the secret:\n") 
		s = self.recvall(80)
		if int(s) == self.secret:
			self.dosend('Wow! How smart you are! Here is your flag:\n')
			self.dosend(flag)
		else:
			self.dosend('Sorry you are wrong!\n')
		exit(0)

class ForkedServer(SocketServer.ForkingTCPServer, SocketServer.TCPServer):
    # Forking TCP server: one child process per connection.
    # allow_reuse_address must be a class attribute. TCPServer.__init__ binds
    # the socket and applies SO_REUSEADDR only if this flag is already set at
    # that moment, so assigning it on the instance after construction (as the
    # __main__ block does) comes too late.
    allow_reuse_address = True


if __name__ == "__main__":
    HOST, PORT = "0.0.0.0", 8848
    server = ForkedServer((HOST, PORT), Task)
    # NOTE(review): the constructor has already bound the socket, and
    # TCPServer applies SO_REUSEADDR only while binding. This assignment
    # therefore has no effect on the bound socket; the flag only works when
    # set on the class before construction.
    server.allow_reuse_address = True
    server.serve_forever()

So the server is defining a custom elliptic curve:

\(y^2 = x^3 + ax + b\) over \(F_q\)

where:

a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd

And a generator \(P = (x, y)\) with:

x = 0xb55c08d92cd878a3ad444a3627a52764f5a402f4a86ef700271cb17edfa739ca
y = 0x49ee01169c130f25853b66b1b97437fb28cfc8ba38b9f497c78f4a09c17a7ab2

The main flow of the server is:

  1. Ask for a proof of work: we must brute-force the 4 missing characters of a string whose SHA-256 digest is given (a tiny partial preimage search, not a collision).
  2. Generate a random secret in the range \([0, q]\).
  3. Compute \(Q = secret \cdot P\) and send us the curve parameters, \(P\) and \(Q\).
  4. Perform one mandatory Exchange operation.
  5. Up to 90 times, run one operation chosen among Exchange, Encrypt and Backdoor.

Code of Exchange:

def exchange(self):
    # Receive the client's point G = (x, y) and derive the session key from
    # secret*G.
    # NOTE: (x, y) is never checked against y^2 = x^3 + a*x + b, so a point
    # from another curve (same `a`, different `b`) is accepted.
    self.dosend("Give me your key:\n") 
    self.dosend("X:\n") 
    x = int(self.recvall(80))
    self.dosend("Y:\n") 
    y = int(self.recvall(80))
    key = (x,y)
    result = mul(self.secret,key)
    self.key = self.pointToKeys(result)
    self.dosend("Exchange success\n")

So the server asks for the coordinates of a point (let's call it \(G\)) and then computes the key as \(key = secret \cdot G\). Note that it never checks that \(G\) actually lies on the curve.

Code of Encrypt:

def encrypt(self):
    # XOR oracle: XOR the bits of our message with the (repeating) key bits.
    self.dosend("Give me your message(hex):\n") 
    msg = self.recvall(150)
    # Bit expansion of the message; its leading zero bits are dropped here.
    data = [int(i) for i in list('{0:0b}'.format(bytes_to_long(msg.decode("hex"))))] 
    enc = [data[i] ^ self.key[i%len(self.key)] for i in range(len(data))]
    result = 0
    for bit in enc:
        result = (result << 1) | bit
    # long_to_bytes drops leading zero bytes, so the reply can be shorter
    # than the message we sent.
    result =  long_to_bytes(result).encode("hex")
    self.dosend("The result is:\n") 
    self.dosend(result + "\n") 

The server asks for a message and XORs its bits with the key (the 512-bit encoding of the shared point, repeated if needed). If we XOR the result with our message, we get the key back.

Code of Backdoor:

def backdoor(self):
    # Send the flag only if the guessed secret is exact. exit(0) runs on both
    # branches, so each session gets a single guess.
    self.dosend("Give me the secret:\n") 
    s = self.recvall(80)
    if int(s) == self.secret:
        self.dosend('Wow! How smart you are! Here is your flag:\n')
        self.dosend(flag)
    else:
        self.dosend('Sorry you are wrong!\n')
    exit(0)

So if we find the secret, we get the flag. We only get one guess, because the session ends after the Backdoor call.

However, finding \(k\) from \(P\) and \(Q = k \cdot P\) is the elliptic curve discrete logarithm problem, for which no efficient (polynomial-time) algorithm is known on a well-chosen curve. (In our case \(k\) is the secret.)

I found a paper describing various attacks on ECC. The one I was interested in is Smart's attack, which works when the order of the curve is equal to \(q\) (an anomalous curve).

We can test this in Sage:

# Smart's attack needs an anomalous curve, i.e. #E(GF(q)) == q.
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa 
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8 
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd 
E = EllipticCurve(GF(q), [a, b])
print(E.order())
print(q)
print(E.order() == q)

Output:

100173830297345234246296808618734115432562869561923600152135698390163541578931
100173830297345234246296808618734115432031228596757431606170574500462062623677
False

OK, so no Smart's attack.

We tried sending the point \((0, 0)\) during the exchange and then encrypting a random message. The ciphertext was equal to the message, so the key was 0. This is because the server uses \((0, 0)\) to represent the point at infinity, and every multiple of it is \((0, 0)\) again. Then I wanted to try some basic multiplications in Sage using the point \((0, 1)\), but:

# (0, 1) satisfies y^2 = x^3 + a*x + b only if b == 1, so Sage rejects it.
G = E([0, 1])
TypeError: Coordinates [0, 1, 1] do not define a point on Elliptic Curve defined by y^2 = x^3 + 34797263276731929172113534470189285565338055361167847976642146658712494152186*x + 39258950039257324692361857404792305546106163510884954845070423459288379905976 over Finite Field of size 100173830297345234246296808618734115432031228596757431606170574500462062623677

Right! The points \((0, 1)\) and \((0, 0)\) are not on the curve (they would require \(b = 1\) and \(b = 0\) respectively)!

So what?

Well we can use the invalid curve attack to find the secret!

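The server never validates that the received point lies on the curve, and its addition formulas never use \(b\). So a point \((x, y)\) that is not on the curve is treated as a point of the curve \(y^2 = x^3 + ax + b'\) with \(b' = y^2 - x^3 - ax\). By choosing \(b'\) so that the order of that curve has small prime factors, we can send points of small order, whose discrete logarithm is easy to compute.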

Example using sage:

sage: E = EllipticCurve(GF(q), [a, 10]) # Define custom curve with b = 10
sage: E
Elliptic Curve defined by y^2 = x^3 + 34797263276731929172113534470189285565338055361167847976642146658712494152186*x + 10 over Finite Field of size 100173830297345234246296808618734115432031228596757431606170574500462062623677
sage: primes = prime_factors(E.order())
sage: primes
[2,
 7,
 43,
 41600427864346027510920601585853037970001163068300848417938134241287552041]
sage: prime = 7 # We want to generate a point which has order 7

sage: G = E.gen(0) * int(E.order() / prime)
sage: G # Base point of the curve
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)

sage: for i in range(10): 
....:     print(G * i) 
....:

(0 : 1 : 0)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 19776018809317549880256385419609476649557176926595835150135592257544590182632 : 1)
(20714360539196670727950256174571184732410240575427565638221344028890014550894 : 88148676486914078619821107426715231757947654441536575462503069115357865865972 : 1)
(20714360539196670727950256174571184732410240575427565638221344028890014550894 : 12025153810431155626475701192018883674083574155220856143667505385104196757705 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 80397811488027684366040423199124638782474051670161596456034982242917472441045 : 1)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 95188686084660826322844665538660405664108096114482774984885118987677350972721 : 1)
(0 : 1 : 0)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 19776018809317549880256385419609476649557176926595835150135592257544590182632 : 1)

We chose 7 as the order and computed a point \(G\) of order 7 by multiplying a generator of the curve by the cofactor, so \(G\) generates the subgroup of order 7. Then we computed \(G * i\) for every i in range(10).

We can see that after the first 7 multiples the points repeat, because \(G\) generates a cyclic subgroup of order 7.

How can this help?

In this case computing the discrete log is very easy, because there are only 7 possible values.

If we now submit this point in the Exchange and then retrieve the key, is it trivial to find the secret?

Well, no. What we find is not the secret but only the secret modulo 7. The secret is chosen in \([0, q]\), so it is presumably about 256 bits long, not a number between 0 and 7.

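What we can do is retrieve several (dlog, prime) pairs, each giving a congruence \(secret \equiv dlog \pmod{prime}\), and solve the resulting system with the Chinese Remainder Theorem. The CRT only determines the secret modulo the product of the primes, so that product must be larger than the secret; to be sure, it must exceed \(q\).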

dlog is the result of the discrete logarithm and prime is the order of the generator.

#!/usr/bin/env python3
# CLIENT

import sys
from pwn import *
from Crypto.Util.number import bytes_to_long as bl
from factordb.factordb import FactorDB
from Crypto.Util.number import long_to_bytes as lb
from sage.all import *
import string

def crack(to_crack, suffix=None):
    """Brute-force the 4-character prefix of the server's proof of work.

    The server prints ``sha256(XXXX+suffix) == digest`` and expects the four
    missing characters ``XXXX``, drawn from ASCII letters and digits.

    :param to_crack: hex SHA-256 digest announced by the server.
    :param suffix: known tail of the proof string. Defaults to the
        module-level ``secret`` parsed from the server banner.
    :returns: the matching 4-character prefix, or ``None`` if none matches.
    """
    import hashlib
    import itertools
    import string

    if suffix is None:
        suffix = secret  # global set by the PoW parsing code
    tail = suffix.encode()  # hoisted: identical for every candidate
    target = to_crack.strip().lower()
    alphabet = string.digits + string.ascii_letters
    for chars in itertools.product(alphabet, repeat=4):
        prefix = "".join(chars)
        if hashlib.sha256(prefix.encode() + tail).hexdigest() == target:
            return prefix
    return None

def xor_string(s1, s2):
    """XOR two big-endian byte strings, aligning them on their last byte.

    The server encodes the ciphertext with ``long_to_bytes``, which drops
    leading zero bytes, so it can be shorter than the message. Both inputs
    are big-endian integer encodings, so the shorter one is left-padded with
    zero bytes. Plain ``zip`` would instead truncate the longer one and
    misalign every byte.

    :returns: ``bytes`` of length ``max(len(s1), len(s2))``.
    """
    width = max(len(s1), len(s2))
    left = bytes(s1).rjust(width, b"\x00")
    right = bytes(s2).rjust(width, b"\x00")
    return bytes(x ^ y for x, y in zip(left, right))

def pad(m, width=256 * 2):
    """Left-pad a bit list with zeros, in place, to ``width`` bits.

    Mirrors the server's ``Task.pad``, which pads to 2 * q.bit_length()
    = 512 bits. The list is modified in place and also returned. A list that
    is already ``width`` bits or longer is returned unchanged. (The previous
    version returned from inside the loop, padding at most one bit and
    returning ``None`` when no padding was needed.)
    """
    missing = width - len(m)
    if missing > 0:
        m[:0] = [0] * missing
    return m

def point_to_keys(p):
    """Client-side mirror of the server's ``Task.pointToKeys``.

    Packs ``x`` into the high 256 bits and ``y`` into the low 256 bits, then
    returns the bits (most significant first) as a list of ints, left-padded
    with zeros to 512 entries. This is self-contained, so it no longer
    depends on a correct ``pad``.
    """
    x, y = p[0], p[1]
    packed = (x << 256) | y
    return [int(bit) for bit in format(packed, "0512b")]

def bytes_to_bit(a):
    """Expand a byte string into a flat list of bits, MSB first per byte."""
    return [int(bit) for byte in a for bit in format(byte, "08b")]

def keys_to_point(k, coord_bits=256):
    """Inverse of ``point_to_keys``: split a key bit list into ``(x, y)``.

    ``x`` is taken from the first ``coord_bits`` bits and ``y`` from the last
    ``coord_bits`` bits. A shorter list (for example, one recovered from a
    ciphertext whose leading zero bytes were dropped) is left-padded with
    zeros first, so the two halves never overlap.
    """
    bits = "".join(str(b) for b in k).rjust(2 * coord_bits, "0")
    x = int(bits[:coord_bits], 2)
    y = int(bits[-coord_bits:], 2)
    return (x, y)

def read_sage_point(P):
    """Parse Sage's projective repr ``"(x : y : z)"`` into ``(x, y)`` strings.

    All parentheses are deleted, the text is split on ``" : "`` and only the
    first two fields are returned; the ``z`` coordinate is dropped.
    """
    stripped = P.translate(str.maketrans('', '', '()'))
    coords = stripped.split(' : ')
    return coords[0], coords[1]

# Server curve y^2 = x^3 + a*x + b over GF(q), copied from task.py.
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd

# Solve PoW: the server prints "sha256(XXXX+<suffix>) == <digest>" and then
# waits for the four missing characters XXXX.
conn = remote('134.175.225.42', 8848)
# recvuntil rather than recv(12): recv(n) may legally return fewer than n
# bytes, which would leave part of the prefix glued to the suffix.
conn.recvuntil(b'sha256(XXXX+')
# Known suffix of the PoW string; crack() reads it as the global `secret`.
secret = conn.recvuntil(b')')[:-1].decode()
conn.recvuntil(b'== ')
to_find = conn.recvline().strip(b'\n').decode()
print(secret)
print(to_find)
pow_prefix = crack(to_find)  # not named `hash`, to keep the builtin usable
if pow_prefix is None:
    raise SystemExit("proof of work: no 4-character prefix matches")
conn.recvuntil(b':')
conn.sendline(pow_prefix.encode())

context.log_level = 'DEBUG'

# CRT inputs: after the loop, dlog[i] == secret mod pp[i].
pp = []
dlog = []
# NOTE(review): blacklist_dlog is never read below.
blacklist_dlog = [30, 42]
for i in range(10):

    # Exchange like a beast
    print("\n\nEXCHANGE\n")
    # G
    # Q = G * secret
    # Random curve y^2 = x^3 + a*x + b' sharing the server's `a`; the
    # server's add()/mul() never use `b`, so they compute on this curve.
    # NOTE(review): this rebinds the global curve coefficient `b`.
    b = randint(10, 50)
    print(f"b: {b}")
    E = EllipticCurve(GF(q), [a, b])
    print(f"Curve: {E}")
    order = E.order()
    print("computing prime factors")
    # Factoring the ~256-bit order here is one of the slow steps.
    primes = prime_factors(order)
    valid = []
    # NOTE(review): this value is 10x smaller than the bound used below.
    # 12298541720533
    for p in primes:
        if p <= 122985417205330:
            valid.append(p)
    # Presumably the largest small-enough prime (prime_factors lists primes
    # in increasing order). Raises IndexError if there is none.
    prime = valid[-1:][0]
    pp.append(prime)
    print(f"prime order: {prime}")
    print("computing generator")
    # Multiplying by the cofactor leaves a point whose order divides `prime`.
    G = E.gen(0) * int(order / prime)
    G = read_sage_point(str(G))
    # The first Exchange is the mandatory one that handle() starts on its own.
    if i != 0:
        conn.recvuntil(b'choice:\n')
        conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(G[0]))
    conn.recvuntil(b'Y:')
    conn.sendline(str(G[1]))

    # Recover secret*G through the Encrypt oracle, then solve the DLOG.
    conn.recvuntil(b'Tell me your choice:')
    conn.sendline('Encrypt')
    conn.recvuntil('Give me your message(hex):')
    # 64 bytes = 512 bits, the key length. The leading 0xf1 has its top bit
    # set, so the server's bit expansion (which drops leading zero bits)
    # keeps all 512 bits aligned with the key.
    to_cipher = ""
    to_cipher += 'f1'
    for _ in range(63):
        to_cipher += '41'
    conn.sendline(to_cipher)

    print(conn.recvuntil(b'is:\n'))
    res = conn.recvline().strip(b'\n').decode()
    # The server's long_to_bytes drops leading zero bytes, so `res` can be
    # shorter than `to_cipher`; the XOR must align the two on their last byte.
    key = xor_string(bytes.fromhex(to_cipher), bytes.fromhex(res))
    Q = keys_to_point(bytes_to_bit(key))
    print(f"Q: {Q}")
    print(f"G: {G}")
    G_e = E([G[0], G[1]])
    # NOTE(review): if secret*G is the server's identity (0,0), i.e. the
    # secret is 0 mod prime, this raises because (0,0) is not on E (b != 0).
    Q_e = E([Q[0], Q[1]])
    print("computing dlog")
    # Q_e = logg * G_e, so logg is the secret modulo the order of G_e.
    logg = G_e.discrete_log(Q_e)
    dlog.append(logg)
    print(f"dlog: {logg}")
    print(f"done: {i}") 

print(f"dlog: {dlog}")
print(f"primes: {pp}")
super_secret = CRT_list(dlog, pp)
print(super_secret)

conn.interactive()

The main problem was time. The server closes the connection after 5 minutes (signal.alarm(300)), and we had 3 main bottlenecks:

  1. Factoring the order of each custom curve.
  2. Finding a point of the chosen prime order.
  3. Solving the discrete logarithm.

What we did was:

  1. Use factordb to factor the order of each curve.
  2. Precompute and cache the generators with an external script.
  3. Solve each discrete logarithm in a separate process.

Exploit

Generator script:

# Offline helper (run in Sage). For each b in 2..49 it builds
# y^2 = x^3 + a*x + b over GF(q) with the server's `a`, factors the curve
# order with factordb, and caches a point whose order is the largest factor
# <= MAX_PRIME, so that its discrete log stays cheap.
MAX_PRIME = 122985417205330

generators = {}

for i in range(2, 50):
    E = EllipticCurve(GF(q), [a, i])
    order = E.order()
    f = FactorDB(order)
    f.connect()
    primes = f.get_factor_list()
    # Pick the largest small-enough factor explicitly instead of relying on
    # the order in which factordb lists the factors.
    prime = max((p for p in primes if p <= MAX_PRIME), default=None)
    if prime is None:
        # No usable factor: skip this b. The old `prime = 0` fallback made
        # the cofactor computation below divide by zero.
        continue
    # Multiplying by the cofactor leaves a point whose order divides `prime`.
    G = E.gen(0) * (order // prime)
    G = read_sage_point(str(G))
    generators[i] = (G[0], G[1], prime)
    print(generators)

Exploit:

#!/usr/bin/env python3

import sys
from pwn import *
from Crypto.Util.number import bytes_to_long as bl
from factordb.factordb import FactorDB
from Crypto.Util.number import long_to_bytes as lb
from sage.all import *
import multiprocessing
import string
import time

def crack(to_crack, suffix=None):
    """Brute-force the 4-character prefix of the server's proof of work.

    The server prints ``sha256(XXXX+suffix) == digest`` and expects the four
    missing characters ``XXXX``, drawn from ASCII letters and digits.

    :param to_crack: hex SHA-256 digest announced by the server.
    :param suffix: known tail of the proof string. Defaults to the
        module-level ``secret`` parsed from the server banner.
    :returns: the matching 4-character prefix, or ``None`` if none matches.
    """
    import hashlib
    import itertools
    import string

    if suffix is None:
        suffix = secret  # global set by the PoW parsing code
    tail = suffix.encode()  # hoisted: identical for every candidate
    target = to_crack.strip().lower()
    alphabet = string.digits + string.ascii_letters
    for chars in itertools.product(alphabet, repeat=4):
        prefix = "".join(chars)
        if hashlib.sha256(prefix.encode() + tail).hexdigest() == target:
            return prefix
    return None

def xor_string(s1, s2):
    """XOR two big-endian byte strings, aligning them on their last byte.

    The server encodes the ciphertext with ``long_to_bytes``, which drops
    leading zero bytes, so it can be shorter than the message. Both inputs
    are big-endian integer encodings, so the shorter one is left-padded with
    zero bytes. Plain ``zip`` would instead truncate the longer one and
    misalign every byte.

    :returns: ``bytes`` of length ``max(len(s1), len(s2))``.
    """
    width = max(len(s1), len(s2))
    left = bytes(s1).rjust(width, b"\x00")
    right = bytes(s2).rjust(width, b"\x00")
    return bytes(x ^ y for x, y in zip(left, right))

def pad(m, width=256 * 2):
    """Left-pad a bit list with zeros, in place, to ``width`` bits.

    Mirrors the server's ``Task.pad``, which pads to 2 * q.bit_length()
    = 512 bits. The list is modified in place and also returned. A list that
    is already ``width`` bits or longer is returned unchanged. (The previous
    version returned from inside the loop, padding at most one bit and
    returning ``None`` when no padding was needed.)
    """
    missing = width - len(m)
    if missing > 0:
        m[:0] = [0] * missing
    return m

def point_to_keys(p):
    """Client-side mirror of the server's ``Task.pointToKeys``.

    Packs ``x`` into the high 256 bits and ``y`` into the low 256 bits, then
    returns the bits (most significant first) as a list of ints, left-padded
    with zeros to 512 entries. This is self-contained, so it no longer
    depends on a correct ``pad``.
    """
    x, y = p[0], p[1]
    packed = (x << 256) | y
    return [int(bit) for bit in format(packed, "0512b")]

def bytes_to_bit(a):
    """Expand a byte string into a flat list of bits, MSB first per byte."""
    return [int(bit) for byte in a for bit in format(byte, "08b")]

def keys_to_point(k, coord_bits=256):
    """Inverse of ``point_to_keys``: split a key bit list into ``(x, y)``.

    ``x`` is taken from the first ``coord_bits`` bits and ``y`` from the last
    ``coord_bits`` bits. A shorter list (for example, one recovered from a
    ciphertext whose leading zero bytes were dropped) is left-padded with
    zeros first, so the two halves never overlap.
    """
    bits = "".join(str(b) for b in k).rjust(2 * coord_bits, "0")
    x = int(bits[:coord_bits], 2)
    y = int(bits[-coord_bits:], 2)
    return (x, y)

def read_sage_point(P):
    """Parse Sage's projective repr ``"(x : y : z)"`` into ``(x, y)`` strings.

    All parentheses are deleted, the text is split on ``" : "`` and only the
    first two fields are returned; the ``z`` coordinate is dropped.
    """
    stripped = P.translate(str.maketrans('', '', '()'))
    coords = stripped.split(' : ')
    return coords[0], coords[1]

def compute_dlog(G, Q, prime, crt_solver):
    """Worker: store ``G.discrete_log(Q)`` in ``crt_solver[prime]``.

    Prints a progress line before and after the (possibly slow) discrete-log
    computation. In this script it runs in a separate process and
    ``crt_solver`` is a ``multiprocessing.Manager().dict()``.
    """
    print(f"trying to crack {Q} = {G} * secret")
    residue = G.discrete_log(Q)
    crt_solver[prime] = residue
    print(f"done dlog: {residue}")

# Server curve y^2 = x^3 + a*x + b over GF(q), copied from task.py.
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd

# Solve PoW: the server prints "sha256(XXXX+<suffix>) == <digest>" and then
# waits for the four missing characters XXXX.
context.log_level = 'DEBUG'
conn = remote('134.175.225.42', 8848)

# recvuntil rather than recv(12): recv(n) may legally return fewer than n
# bytes, which would leave part of the prefix glued to the suffix.
conn.recvuntil(b'sha256(XXXX+')
# Known suffix of the PoW string; crack() reads it as the global `secret`.
secret = conn.recvuntil(b')')[:-1].decode()
conn.recvuntil(b'== ')
to_find = conn.recvline().strip(b'\n').decode()
print(secret)
print(to_find)
pow_prefix = crack(to_find)  # not named `hash`, to keep the builtin usable
if pow_prefix is None:
    raise SystemExit("proof of work: no 4-character prefix matches")
conn.recvuntil(b':')
conn.sendline(pow_prefix.encode())


# We take only the points with a medium/high order.
# Cached output of the generator script: b -> (Gx, Gy, p), where (Gx, Gy)
# is a point of y^2 = x^3 + a*x + b over GF(q) whose order is presumably
# the prime p.
# NOTE(review): the product of these nine orders is about 2^253.8, slightly
# below q (about 2^255.8). The CRT result is therefore the secret only when
# the secret is smaller than that product (roughly 1 run in 4). One more
# entry with a new prime order >= 5 would remove the guesswork.
generators = {
  2: ('81972374812864690693014852714727745978362841318289396397361414563375734501921', '51896212619465579275373796571948358799405679959558638822560270744499831979432', 1327481),
  3: ('53326919872607738273172644810428986036094490647068694645393694893278040069026', '58081551953623943565565755936599269753986052376062491850282542721957212970679', 71337711241),
  4: ('8148101487581648484053574525912222209338990886522507033183350100344852533064', '15225235739470967392827134243865327994823593728240653822041518660024750505853', 287038167079),
  8: ('50666337867992575803820658750430197277695458978329092014096425844774763149719', '13421860369980853123427220383308870579014596017199802188611416857274407341576', 14007550114741),
  9: ('27693266742860373879830235866293233339575293329402286516010276136420544358905', '8279872853041474338063561217165295999844712166288430478991978486297227230411', 17003163806851),
  11: ('73177225050788643134380595317533550828476603067177547542601837469105448675189', '16845507613964658793645291408091138189402262042994290641383029874357951799263', 80599),
  12: ('87927856259936467933865576887569931721798540127844175083299758091155998566253', '73204268729977825848740410829132875207119846783884288746044087556556009530393', 19813513),
  14: ('20443758771121867411286144646166868372450547941066474097899222621364635384017', '71171557090899672647440713736403327367383103843027405297306364952443383551107', 20731),
  15: ('89027666840299255236885552248007867044183365676188167946721016933234594434812', '90004068319753094768603992580255539598521077777716506807086729772395186348302', 119653)
}

# Each discrete log runs in its own process so the network dialogue can go
# on meanwhile. Workers store {prime: secret mod prime} in crt_solver.
manager = multiprocessing.Manager()
crt_solver = manager.dict()
processes = []
for i in range(len(list(generators.keys()))):

    # Exchange like a beast
    # NOTE(review): this rebinds the global curve coefficient `b`.
    b = list(generators.keys())[i]
    E = EllipticCurve(GF(q), [a, b])
    G = generators[b][:2]
    print("\n\nEXCHANGE\n")
    print(f"trying b: {b}")

    # The first Exchange is the mandatory one that handle() starts on its own.
    if i != 0:
        conn.recvuntil(b'choice:\n')
        conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(G[0]))
    conn.recvuntil(b'Y:')
    conn.sendline(str(G[1]))

    # Recover secret*G through the Encrypt oracle, then solve the DLOG.
    conn.recvuntil(b'Tell me your choice:')
    conn.sendline('Encrypt')
    conn.recvuntil('Give me your message(hex):')
    # 64 bytes = 512 bits, the key length. The leading 0xf1 has its top bit
    # set, so the server's bit expansion (which drops leading zero bits)
    # keeps all 512 bits aligned with the key.
    to_cipher = ""
    to_cipher += 'f1'
    for _ in range(63):
        to_cipher += '41'
    conn.sendline(to_cipher)

    print(conn.recvuntil(b'is:\n'))
    res = conn.recvline().strip(b'\n').decode()
    # The server's long_to_bytes drops leading zero bytes, so `res` can be
    # shorter than `to_cipher`; the XOR must align the two on their last byte.
    key = xor_string(bytes.fromhex(to_cipher), bytes.fromhex(res))
    Q = keys_to_point(bytes_to_bit(key))
    print(f"Q: {Q}")
    print(f"G: {G}")
    G_e = E([G[0], G[1]])
    # NOTE(review): if secret*G is the server's identity (0,0), i.e. the
    # secret is 0 mod the order of G, this raises because (0,0) is not on E.
    Q_e = E([Q[0], Q[1]])
    proc = multiprocessing.Process(target=compute_dlog, args=(G_e, Q_e, generators[b][2], crt_solver))
    processes.append(proc)
    proc.start()
    print(f"done: {i}") 

# Useless but fun
# Six more Exchanges with (1, 1), each followed by a 30 s sleep (180 s in
# total). NOTE(review): the join() below already waits for the workers, and
# the server's signal.alarm(300) keeps running during these sleeps.
for _ in range(6):
    conn.recvuntil(b'choice:\n')
    conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(1))
    conn.recvuntil(b'Y:')
    conn.sendline(str(1))
    time.sleep(30)

for p in processes:
    p.join()

# The dict is no longer modified after join(), so values() and keys() come
# back in matching order: residues line up with their moduli.
super_secret = CRT_list(list(crt_solver.values()), list(crt_solver.keys()))
print(super_secret)

# Single guess: the server ends the session after Backdoor either way.
conn.recvuntil(b'choice:\n')
conn.sendline(b'Backdoor')
conn.sendline(str(super_secret))
conn.interactive()

And after a while...

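YEAH! We managed to get the flag just in time, phew.

Note that the product of the nine primes used in the exploit is about \(2^{254}\), slightly less than \(q\). The CRT result therefore equals the secret only when the secret happens to be smaller than that product, roughly one run in four. Adding one more cached point of a new prime order (any prime of at least 5) would make the attack deterministic.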

Flag

De1CTF{c47b5984-1a7c-49f5-a2e3-525d83b50ecf}

Fails

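We spent a lot of time parsing the key correctly and converting it back to a point, and the XOR step gave us some problems. For instance, the server's long_to_bytes drops leading zero bytes, so the ciphertext can be shorter than the message.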