Invalid curve attack on elliptic curve.
Information
- category : crypto
- points : 338
Description
Exchange your keys
nc 134.175.225.42 8848
1 file: task.py
Writeup
I solved this challenge together with The_Lillo and 0ssigeno.
We are given a task.py
which is the script running on the server:
import os,random,sys,string
from hashlib import sha256
import SocketServer
import signal
from FLAG import flag
from gmpy2 import invert
from Crypto.Util.number import bytes_to_long, long_to_bytes
# Modulus q and coefficients a, b of the curve y^2 = x^3 + a*x + b used by the server.
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
# Sentinel tuple that add() and mul() treat as the identity element.
zero = (0,0)
def add(p1,p2):
    """Add two affine points with the chord-and-tangent formulas modulo q.

    The tuple ``zero`` acts as the identity. Only the module-level ``a`` and
    ``q`` enter the computation; ``b`` is never read, and the inputs are not
    checked against the curve equation.
    """
    if p1 == zero:
        return p2
    if p2 == zero:
        return p1
    x1, y1 = p1
    x2, y2 = p2
    if x1 == x2:
        # Same x: either the points cancel out or this is a doubling.
        if y1 != y2 or y1 == 0:
            return zero
        slope = (3 * x1 * x1 + a) * invert(2 * y1 , q) % q
    else:
        slope = (y2 - y1) * invert(x2 - x1 , q) % q
    x3 = (slope * slope - x1 - x2) % q
    y3 = (slope * (x1 - x3) - y1) % q
    return (int(x3),int(y3))
def mul(n,p):
    """Return n * p using right-to-left double-and-add built on add().

    For n <= 0 the loop never runs and ``zero`` is returned.
    """
    acc = zero
    addend = p
    while n > 0:
        if n & 1:
            acc = add(acc, addend)
        addend = add(addend, addend)
        n >>= 1
    return acc
def pointToString(p):
    """Format the first two entries of p as "(x,y)" with no spaces."""
    return "(%s,%s)" % (p[0], p[1])
# Coordinates of the base point P; handle() multiplies it by the secret to get Q.
Px = 0xb55c08d92cd878a3ad444a3627a52764f5a402f4a86ef700271cb17edfa739ca
Py = 0x49ee01169c130f25853b66b1b97437fb28cfc8ba38b9f497c78f4a09c17a7ab2
P = (Px,Py)
class Task(SocketServer.BaseRequestHandler):
    """Per-connection handler implementing the "ECDH System" protocol.

    NOTE(review): the listing had lost its indentation; the nesting below is
    reconstructed from the statement structure.
    """

    def proof_of_work(self):
        """Send a sha256 challenge and return True iff the client answers it.

        A 20-character proof is drawn from letters and digits; the client is
        shown the digest and proof[4:] and must send back the 4 missing
        leading characters.
        """
        random.seed(os.urandom(8))
        proof = "".join([random.choice(string.ascii_letters+string.digits) for _ in range(20)])
        digest = sha256(proof).hexdigest()
        self.request.send("sha256(XXXX+%s) == %s\n" % (proof[4:],digest))
        self.request.send("Give me XXXX:")
        x = self.request.recv(10)
        x = x.strip()
        if len(x) != 4 or sha256(x+proof[4:]).hexdigest() != digest:
            return False
        return True

    def recvall(self, sz):
        """Read until a newline-terminated chunk or sz bytes have arrived.

        Returns the stripped data, or "" if anything raised while reading.
        """
        try:
            r = sz
            res = ""
            while r > 0:
                res += self.request.recv(r)
                if res.endswith("\n"):
                    r = 0
                else:
                    r = sz - len(res)
            res = res.strip()
        except:
            res = ""
        return res.strip("\n")

    def dosend(self, msg):
        """Send msg with sendall, silently ignoring any error."""
        try:
            self.request.sendall(msg)
        except:
            pass

    def handle(self):
        """Run one session.

        After the proof of work: arm a 300 second alarm, draw the secret, send
        the parameters together with Q = secret * P, perform one exchange and
        then serve up to 90 menu choices.
        """
        try:
            if not self.proof_of_work():
                return
            signal.alarm(300)
            self.secret = random.randint(0,q)
            Q = mul(self.secret,P)
            self.dosend("Welcome to the ECDH System.\n")
            self.dosend("The params are: \n")
            self.dosend("q: " + str(q) + "\n")
            self.dosend("a: " + str(a) + "\n")
            self.dosend("b: " + str(b) + "\n")
            self.dosend("P: " + pointToString(P) + "\n")
            self.dosend("Q: " + pointToString(Q) + "\n")
            # The first exchange is mandatory and happens before the menu.
            self.exchange()
            for _ in range(90):
                self.dosend("Tell me your choice:\n")
                choice = self.recvall(9)
                if choice == "Exchange":
                    self.exchange()
                elif choice == "Encrypt":
                    self.encrypt()
                elif choice == "Backdoor":
                    self.backdoor()
                else:
                    self.dosend("No such choice!\n")
            self.dosend("Bye bye~\n")
            self.request.close()
        except:
            # Bare except: any failure in the session ends up here
            # (presumably including the exit(0) issued by backdoor()).
            self.dosend("Something error!\n")
            self.request.close()

    def pad(self,m):
        """Left-pad the bit list m in place with zeros to 2 * q.bit_length() entries."""
        pad_length = q.bit_length()*2 - len(m)
        for _ in range(pad_length):
            m.insert(0,0)
        return m

    def encrypt(self):
        """XOR a client-supplied hex message, bit by bit, with the session key.

        The message is turned into its binary digits without leading zeros,
        XORed with self.key (repeated cyclically) and returned as hex.
        """
        self.dosend("Give me your message(hex):\n")
        msg = self.recvall(150)
        data = [int(i) for i in list('{0:0b}'.format(bytes_to_long(msg.decode("hex"))))]
        enc = [data[i] ^ self.key[i%len(self.key)] for i in range(len(data))]
        result = 0
        for bit in enc:
            result = (result << 1) | bit
        result = long_to_bytes(result).encode("hex")
        self.dosend("The result is:\n")
        self.dosend(result + "\n")

    def pointToKeys(self,p):
        """Serialise point p as the padded bit list of (x << q.bit_length()) | y."""
        x = p[0]
        y = p[1]
        tmp = x << q.bit_length() | y
        res = self.pad([int(i) for i in list('{0:0b}'.format(tmp))])
        return res

    def exchange(self):
        """Read a point (X, Y) from the client and derive the key secret * (X, Y).

        The received coordinates are passed to mul() as-is; nothing here
        checks that they satisfy the curve equation.
        """
        self.dosend("Give me your key:\n")
        self.dosend("X:\n")
        x = int(self.recvall(80))
        self.dosend("Y:\n")
        y = int(self.recvall(80))
        key = (x,y)
        result = mul(self.secret,key)
        self.key = self.pointToKeys(result)
        self.dosend("Exchange success\n")

    def backdoor(self):
        """Send the flag if the client supplies the secret, then call exit(0).

        NOTE(review): exit(0) is placed at method level here (run after either
        branch); confirm against the original file.
        """
        self.dosend("Give me the secret:\n")
        s = self.recvall(80)
        if int(s) == self.secret:
            self.dosend('Wow! How smart you are! Here is your flag:\n')
            self.dosend(flag)
        else:
            self.dosend('Sorry you are wrong!\n')
        exit(0)
class ForkedServer(SocketServer.ForkingTCPServer, SocketServer.TCPServer):
    # Server class used in __main__; it only combines the two SocketServer
    # bases and adds no behaviour of its own.
    pass
if __name__ == "__main__":
    HOST, PORT = "0.0.0.0", 8848
    # allow_reuse_address is read by server_bind(), which runs inside the
    # constructor. Setting it on the instance afterwards has no effect, so it
    # is set on the class before the server is created.
    ForkedServer.allow_reuse_address = True
    server = ForkedServer((HOST, PORT), Task)
    server.serve_forever()
So the server is defining a custom elliptic curve:
\(y^2 = x^3 + ax + b\) over \(F_q\)
where:
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd
And a generator \(P = (x, y)\) with:
x = 0xb55c08d92cd878a3ad444a3627a52764f5a402f4a86ef700271cb17edfa739ca
y = 0x49ee01169c130f25853b66b1b97437fb28cfc8ba38b9f497c78f4a09c17a7ab2
The main flow of the server is:
- Ask for a Proof of Work: we must brute-force the 4 missing leading characters of a string whose sha256 digest is given.
- Generate a random secret between \(0\) and \(q\).
- Compute \(Q = P * secret\).
- Perform one initial Exchange operation.
- Then, 90 times, perform 1 operation chosen between Exchange, Encrypt and Backdoor.
Code of Exchange
:
def exchange(self):
    """Read a point (X, Y) from the client and derive the key secret * (X, Y).

    The coordinates go straight into mul(); nothing here checks that the
    point satisfies the curve equation.
    """
    self.dosend("Give me your key:\n")
    self.dosend("X:\n")
    x = int(self.recvall(80))
    self.dosend("Y:\n")
    y = int(self.recvall(80))
    key = (x,y)
    result = mul(self.secret,key)
    self.key = self.pointToKeys(result)
    self.dosend("Exchange success\n")
So the server asks for the coordinates of a point (let’s call this point \(G\)),
and then computes key
as \(key = G * secret\).
Code of Encrypt
:
def encrypt(self):
    """XOR a client-supplied hex message, bit by bit, with the session key.

    The message is converted to its binary digits (no leading zeros), XORed
    with self.key repeated cyclically, and sent back hex-encoded.
    """
    self.dosend("Give me your message(hex):\n")
    msg = self.recvall(150)
    data = [int(i) for i in list('{0:0b}'.format(bytes_to_long(msg.decode("hex"))))]
    enc = [data[i] ^ self.key[i%len(self.key)] for i in range(len(data))]
    result = 0
    for bit in enc:
        result = (result << 1) | bit
    result = long_to_bytes(result).encode("hex")
    self.dosend("The result is:\n")
    self.dosend(result + "\n")
The server asks for a message and then xors our message with the key
. If we xor
the result back with our message we obtain the key
.
Code of Backdoor
:
def backdoor(self):
    """Send the flag if the client supplies the secret, then call exit(0).

    NOTE(review): exit(0) is placed at method level here (run after either
    branch); confirm against the original file.
    """
    self.dosend("Give me the secret:\n")
    s = self.recvall(80)
    if int(s) == self.secret:
        self.dosend('Wow! How smart you are! Here is your flag:\n')
        self.dosend(flag)
    else:
        self.dosend('Sorry you are wrong!\n')
    exit(0)
So if we guess/find correctly the secret we can get the flag.
However, finding \(k\) from \(Q, P\) where \(Q = P * k\) is known as the discrete logarithm problem, and no polynomial-time algorithm is known for it in general. (In our case \(k\) is the secret).
I found this paper which describes various attacks on ECC; the one I was interested in is Smart’s attack, which is possible when the order of the curve is equal to \(q\).
We can test this in sage:
# Curve parameters taken from task.py
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd
E = EllipticCurve(GF(q), [a, b])
# Smart's attack needs the curve order to equal q, so print both and compare.
print(E.order())
print(q)
print(E.order() == q)
Output:
100173830297345234246296808618734115432562869561923600152135698390163541578931
100173830297345234246296808618734115432031228596757431606170574500462062623677
False
OK, so Smart’s attack is not applicable.
We tried to set the point \((0, 0)\) during the exchange and then encrypting a
random message. We saw that the ciphertext was equal to the message, so the key
was set to 0. Then I wanted to try some basic multiplications on sage using the
point \((0, 1)\) but:
G = E([0, 1])
TypeError: Coordinates [0, 1, 1] do not define a point on Elliptic Curve defined by y^2 = x^3 + 34797263276731929172113534470189285565338055361167847976642146658712494152186*x + 39258950039257324692361857404792305546106163510884954845070423459288379905976 over Finite Field of size 100173830297345234246296808618734115432031228596757431606170574500462062623677
Right! The point \((0, 1)\) and \((0, 0)\) are not on the curve!
So what?
Well we can use the invalid curve attack to find the secret!
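The server never checks that the point we send is on the curve, and its addition formulas use only \(a\) and \(q\), never \(b\). So if we send a point lying on a different curve \(y^2 = x^3 + ax + b'\), the server unknowingly computes \(secret * G\) on that curve. We can choose \(b'\) so that the order of the new curve has small prime factors and send a point whose order is one of them: the discrete logarithm in such a small subgroup is easy to compute.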
Example using sage:
sage: E = EllipticCurve(GF(q), [a, 10]) # Define custom curve with b = 10
sage: E
Elliptic Curve defined by y^2 = x^3 + 34797263276731929172113534470189285565338055361167847976642146658712494152186*x + 10 over Finite Field of size 100173830297345234246296808618734115432031228596757431606170574500462062623677
sage: primes = prime_factors(E.order())
sage: primes
[2,
7,
43,
41600427864346027510920601585853037970001163068300848417938134241287552041]
sage: prime = 7 # We want to generate a point which has order 7
sage: G = E.gen(0) * int(E.order() / prime)
sage: G # Base point of the curve
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)
sage: for i in range(10):
....: print(G * i)
....:
(0 : 1 : 0)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 19776018809317549880256385419609476649557176926595835150135592257544590182632 : 1)
(20714360539196670727950256174571184732410240575427565638221344028890014550894 : 88148676486914078619821107426715231757947654441536575462503069115357865865972 : 1)
(20714360539196670727950256174571184732410240575427565638221344028890014550894 : 12025153810431155626475701192018883674083574155220856143667505385104196757705 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 80397811488027684366040423199124638782474051670161596456034982242917472441045 : 1)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 95188686084660826322844665538660405664108096114482774984885118987677350972721 : 1)
(0 : 1 : 0)
(23800613574902822049829027077760684224803559019003054675408332590180803344147 : 4985144212684407923452143080073709767923132482274656621285455512784711650956 : 1)
(89600239463847691447848720216683885818123916754582881936284294745413040375092 : 19776018809317549880256385419609476649557176926595835150135592257544590182632 : 1)
We chose 7 as the order,
and then we computed a point \(G\) of order 7 by multiplying a generator of the curve by \(order / 7\).
Then we computed \(G * i\) for every i in range(10).
We can see that after the first 7
multiples the points start repeating. This
is because \(G\) generates a cyclic subgroup of order 7.
How can this help?
Well in this case computing discrete log is very easy because there are few possible cases (7).
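If we now submit this point in the Exchange
to the server and then retrieve
the key, is it trivial to find the secret?
Well, no. What we find is not the secret itself but only \(secret \bmod 7\): the secret lies between \(0\) and \(q\) and is presumably around 250/255 bits long, while a discrete logarithm in a group of order 7 can only take values between 0 and 6.
What we can do is retrieve several pairs (dlog, prime)
and build the
system of linear congruences \(secret \equiv dlog_i \pmod{prime_i}\), which is solvable using the
Chinese Remainder Theorem.
dlog
is the result of the discrete logarithm and prime
is the order of the
generator. The solution is unique only modulo the product of the primes, so that product has to be larger than the secret to recover it completely.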
#!/usr/bin/env python3
# CLIENT
import sys
from pwn import *
from Crypto.Util.number import bytes_to_long as bl
from factordb.factordb import FactorDB
from Crypto.Util.number import long_to_bytes as lb
from sage.all import *
import string
def crack(to_crack, suffix=None):
    """Brute-force the 4-character proof-of-work prefix.

    Tries every 4-character string over digits + ASCII letters (digits
    first) and returns the first ``prefix`` for which
    ``sha256(prefix + suffix).hexdigest() == to_crack``.

    Args:
        to_crack: target sha256 hex digest sent by the server.
        suffix: known tail of the proof. When omitted, the module-level
            ``secret`` set by the calling script is used, as before.

    Returns:
        The matching 4-character prefix, or None if no candidate matches.
    """
    # Local imports: the script otherwise gets hashlib only through its
    # star-imports.
    import hashlib
    from itertools import product

    if suffix is None:
        suffix = secret
    tail = suffix.encode()
    alphabet = string.digits + string.ascii_letters
    for combo in product(alphabet, repeat=4):
        prefix = "".join(combo)
        if hashlib.sha256(prefix.encode() + tail).hexdigest() == to_crack:
            return prefix
    return None
def xor_string(s1, s2):
    """XOR two byte strings pairwise; the result has the length of the shorter one."""
    out = bytearray()
    for left, right in zip(s1, s2):
        out.append(left ^ right)
    return bytes(out)
def pad(m):
    """Left-pad the bit list m in place with zeros up to 512 entries and return it."""
    missing = 256 * 2 - len(m)
    if missing > 0:
        # One slice assignment at the front instead of repeated insert(0, 0).
        m[:0] = [0] * missing
    return m
def point_to_keys(p):
    """Serialise point p as the 512-entry bit list of (x << 256) | y."""
    x, y = p[0], p[1]
    packed = (x << 256) | y
    bits = [int(ch) for ch in format(packed, 'b')]
    return pad(bits)
def bytes_to_bit(a):
    """Expand a byte string into a flat list of bits, most significant bit first."""
    return [int(ch) for byte in a for ch in format(byte, '08b')]
def keys_to_point(k):
    """Rebuild (x, y) from a key bit list: x is the first 256 bits, y the last 256."""
    bitstring = ''.join(map(str, k))
    return (int(bitstring[:256], 2), int(bitstring[-256:], 2))
def read_sage_point(P):
    """Parse Sage's "(x : y : z)" point string and return (x, y) as strings."""
    fields = P.translate(str.maketrans('', '', '()')).split(' : ')
    return (fields[0], fields[1])
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd

# Solve PoW
conn = remote('134.175.225.42', 8848)
conn.recv(12)  # skip the leading "sha256(XXXX+"
secret = conn.recvuntil(b')')[:-1].decode()  # known tail of the proof, read by crack()
conn.recvuntil(b'== ')
to_find = conn.recvline().strip(b'\n').decode()
print(secret)
print(to_find)
pow_prefix = crack(to_find)
conn.recvuntil(b':')
conn.sendline(pow_prefix.encode())

context.log_level = 'DEBUG'
pp = []     # subgroup orders (CRT moduli)
dlog = []   # secret modulo the matching entry of pp

for i in range(10):
    # Exchange like a beast
    print("\n\nEXCHANGE\n")
    # G
    # Q = G * secret
    # Redraw b until the invalid curve has a prime factor small enough for
    # the dlog; indexing an empty list here would kill the session after
    # the PoW has already been paid for.
    while True:
        b = randint(10, 50)
        print(f"b: {b}")
        E = EllipticCurve(GF(q), [a, b])
        print(f"Curve: {E}")
        order = E.order()
        print("computing prime factors")
        valid = [p for p in prime_factors(order) if p <= 122985417205330]
        if valid:
            break
    prime = valid[-1]
    pp.append(prime)
    print(f"prime order: {prime}")
    print("computing generator")
    # Multiplying by the cofactor yields a point of order `prime`.
    G = E.gen(0) * (order // prime)
    G = read_sage_point(str(G))

    # The server performs the first Exchange by itself right after the banner.
    if i != 0:
        conn.recvuntil(b'choice:\n')
        conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(G[0]).encode())
    conn.recvuntil(b'Y:')
    conn.sendline(str(G[1]).encode())

    # Solve DLOG muthafucka
    conn.recvuntil(b'Tell me your choice:')
    conn.sendline(b'Encrypt')
    conn.recvuntil(b'Give me your message(hex):')
    # 64 bytes with the top bit set, so the server XORs all 512 key bits.
    to_cipher = 'f1' + '41' * 63
    conn.sendline(to_cipher.encode())
    print(conn.recvuntil(b'is:\n'))
    res = conn.recvline().strip(b'\n').decode()
    key = xor_string(bytes.fromhex(to_cipher), bytes.fromhex(res))
    Q = keys_to_point(bytes_to_bit(key))
    print(f"Q: {Q}")
    print(f"G: {G}")
    G_e = E([G[0], G[1]])
    Q_e = E([Q[0], Q[1]])
    print("computing dlog")
    logg = G_e.discrete_log(Q_e)
    dlog.append(logg)
    print(f"dlog: {logg}")
    print(f"done: {i}")

print(f"dlog: {dlog}")
print(f"primes: {pp}")
# secret = dlog[i] (mod pp[i]) for every i; combine the congruences with CRT
super_secret = CRT_list(dlog, pp)
print(super_secret)
conn.interactive()
The main problem was time. The server has a connection timeout of 5 minutes and we had 3 main bottlenecks:
- Computing the prime factors of the order of the custom curve.
- Finding a generator of the curve.
- Solving the dlog.
What we did was:
- Use factordb to factorize the order of the curve.
- Cache various generators using an external script.
- Solve the discrete logarithms in separate processes.
Exploit
Generator script:
# Precompute, for each candidate b, a point of small prime order on the
# invalid curve y^2 = x^3 + a*x + b. Relies on q, a, read_sage_point and
# FactorDB from the exploit script.
generators = {}
for i in range(2, 50):
    if generators.get(i, None) is not None:
        continue
    E = EllipticCurve(GF(q), [a, i])
    order = E.order()
    f = FactorDB(order)
    f.connect()
    small = [p for p in f.get_factor_list() if p <= 122985417205330]
    if not small:
        # No factor below the bound: skip this b instead of dividing by zero.
        continue
    prime = small[-1]
    # Multiplying by the cofactor yields a point of order `prime`; floor
    # division is exact because prime divides order.
    G = E.gen(0) * (order // prime)
    G = read_sage_point(str(G))
    generators[i] = (G[0], G[1], prime)
print(generators)
Exploit:
#!/usr/bin/env python3
import sys
from pwn import *
from Crypto.Util.number import bytes_to_long as bl
from factordb.factordb import FactorDB
from Crypto.Util.number import long_to_bytes as lb
from sage.all import *
import multiprocessing
import string
import time
def crack(to_crack, suffix=None):
    """Brute-force the 4-character proof-of-work prefix.

    Tries every 4-character string over digits + ASCII letters (digits
    first) and returns the first ``prefix`` for which
    ``sha256(prefix + suffix).hexdigest() == to_crack``.

    Args:
        to_crack: target sha256 hex digest sent by the server.
        suffix: known tail of the proof. When omitted, the module-level
            ``secret`` set by the calling script is used, as before.

    Returns:
        The matching 4-character prefix, or None if no candidate matches.
    """
    # Local imports: the script otherwise gets hashlib only through its
    # star-imports.
    import hashlib
    from itertools import product

    if suffix is None:
        suffix = secret
    tail = suffix.encode()
    alphabet = string.digits + string.ascii_letters
    for combo in product(alphabet, repeat=4):
        prefix = "".join(combo)
        if hashlib.sha256(prefix.encode() + tail).hexdigest() == to_crack:
            return prefix
    return None
def xor_string(s1, s2):
    """XOR two byte strings pairwise; the result has the length of the shorter one."""
    out = bytearray()
    for left, right in zip(s1, s2):
        out.append(left ^ right)
    return bytes(out)
def pad(m):
    """Left-pad the bit list m in place with zeros up to 512 entries and return it."""
    missing = 256 * 2 - len(m)
    if missing > 0:
        # One slice assignment at the front instead of repeated insert(0, 0).
        m[:0] = [0] * missing
    return m
def point_to_keys(p):
    """Serialise point p as the 512-entry bit list of (x << 256) | y."""
    x, y = p[0], p[1]
    packed = (x << 256) | y
    bits = [int(ch) for ch in format(packed, 'b')]
    return pad(bits)
def bytes_to_bit(a):
    """Expand a byte string into a flat list of bits, most significant bit first."""
    return [int(ch) for byte in a for ch in format(byte, '08b')]
def keys_to_point(k):
    """Rebuild (x, y) from a key bit list: x is the first 256 bits, y the last 256."""
    bitstring = ''.join(map(str, k))
    return (int(bitstring[:256], 2), int(bitstring[-256:], 2))
def read_sage_point(P):
    """Parse Sage's "(x : y : z)" point string and return (x, y) as strings."""
    fields = P.translate(str.maketrans('', '', '()')).split(' : ')
    return (fields[0], fields[1])
def compute_dlog(G, Q, prime, crt_solver):
    """Worker: solve Q = G * secret and record the result under ``prime``.

    The discrete logarithm returned by ``G.discrete_log(Q)`` is stored as
    ``crt_solver[prime]``; nothing is returned.
    """
    print(f"trying to crack {Q} = {G} * secret")
    residue = G.discrete_log(Q)
    crt_solver[prime] = residue
    print(f"done dlog: {residue}")
a = 0x4cee8d95bb3f64db7d53b078ba3a904557425e2a6d91c5dfbf4c564a3f3619fa
b = 0x56cbc73d8d2ad00e22f12b930d1d685136357d692fa705dae25c66bee23157b8
q = 0xdd7860f2c4afe6d96059766ddd2b52f7bb1ab0fce779a36f723d50339ab25bbd

# Solve PoW
context.log_level = 'DEBUG'
conn = remote('134.175.225.42', 8848)
conn.recv(12)  # skip the leading "sha256(XXXX+"
secret = conn.recvuntil(b')')[:-1].decode()  # known tail of the proof, read by crack()
conn.recvuntil(b'== ')
to_find = conn.recvline().strip(b'\n').decode()
print(secret)
print(to_find)
pow_prefix = crack(to_find)
conn.recvuntil(b':')
conn.sendline(pow_prefix.encode())

# We take only the points with a medium/high order.
# Maps b -> (x, y, order) of a precomputed point on y^2 = x^3 + a*x + b.
generators = {
    2: ('81972374812864690693014852714727745978362841318289396397361414563375734501921', '51896212619465579275373796571948358799405679959558638822560270744499831979432', 1327481),
    3: ('53326919872607738273172644810428986036094490647068694645393694893278040069026', '58081551953623943565565755936599269753986052376062491850282542721957212970679', 71337711241),
    4: ('8148101487581648484053574525912222209338990886522507033183350100344852533064', '15225235739470967392827134243865327994823593728240653822041518660024750505853', 287038167079),
    8: ('50666337867992575803820658750430197277695458978329092014096425844774763149719', '13421860369980853123427220383308870579014596017199802188611416857274407341576', 14007550114741),
    9: ('27693266742860373879830235866293233339575293329402286516010276136420544358905', '8279872853041474338063561217165295999844712166288430478991978486297227230411', 17003163806851),
    11: ('73177225050788643134380595317533550828476603067177547542601837469105448675189', '16845507613964658793645291408091138189402262042994290641383029874357951799263', 80599),
    12: ('87927856259936467933865576887569931721798540127844175083299758091155998566253', '73204268729977825848740410829132875207119846783884288746044087556556009530393', 19813513),
    14: ('20443758771121867411286144646166868372450547941066474097899222621364635384017', '71171557090899672647440713736403327367383103843027405297306364952443383551107', 20731),
    15: ('89027666840299255236885552248007867044183365676188167946721016933234594434812', '90004068319753094768603992580255539598521077777716506807086729772395186348302', 119653)
}

manager = multiprocessing.Manager()
crt_solver = manager.dict()  # filled by the workers: order -> secret mod order
processes = []

for i, b in enumerate(generators):
    # Exchange like a beast
    gx, gy, prime = generators[b]
    E = EllipticCurve(GF(q), [a, b])
    G = (gx, gy)
    print("\n\nEXCHANGE\n")
    print(f"trying b: {b}")

    # The server performs the first Exchange by itself right after the banner.
    if i != 0:
        conn.recvuntil(b'choice:\n')
        conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(G[0]).encode())
    conn.recvuntil(b'Y:')
    conn.sendline(str(G[1]).encode())

    # Solve DLOG muthafucka
    conn.recvuntil(b'Tell me your choice:')
    conn.sendline(b'Encrypt')
    conn.recvuntil(b'Give me your message(hex):')
    # 64 bytes with the top bit set, so the server XORs all 512 key bits.
    to_cipher = 'f1' + '41' * 63
    conn.sendline(to_cipher.encode())
    print(conn.recvuntil(b'is:\n'))
    res = conn.recvline().strip(b'\n').decode()
    key = xor_string(bytes.fromhex(to_cipher), bytes.fromhex(res))
    Q = keys_to_point(bytes_to_bit(key))
    print(f"Q: {Q}")
    print(f"G: {G}")
    G_e = E([G[0], G[1]])
    Q_e = E([Q[0], Q[1]])
    # Each dlog runs in its own process so the session keeps moving.
    proc = multiprocessing.Process(target=compute_dlog, args=(G_e, Q_e, prime, crt_solver))
    processes.append(proc)
    proc.start()
    print(f"done: {i}")

# Useless but fun
for _ in range(6):
    conn.recvuntil(b'choice:\n')
    conn.sendline(b'Exchange')
    conn.recvuntil(b'X:')
    conn.sendline(str(1).encode())
    conn.recvuntil(b'Y:')
    conn.sendline(str(1).encode())

# join() already blocks until every dlog is done, so no fixed sleep is needed.
for p in processes:
    p.join()

# NOTE(review): CRT only pins the secret modulo the product of the subgroup
# orders. If that product is smaller than q, the value below is the secret
# only when the secret itself is below the product; confirm the hard-coded
# orders multiply to more than q, or add generators.
super_secret = CRT_list(list(crt_solver.values()), list(crt_solver.keys()))
print(super_secret)
conn.recvuntil(b'choice:\n')
conn.sendline(b'Backdoor')
conn.sendline(str(super_secret).encode())
conn.interactive()
And after a while..
YEAH! We managed just in time to get the flag, fiu.
Flag
De1CTF{c47b5984-1a7c-49f5-a2e3-525d83b50ecf}
Fails
We spent a lot of time trying to correctly parse the key and convert it to a point, and the xor gave us some problems.