Reverse an encryption algorithm with the help of radare2
Information
- category: reverse
- points: 301
- solves: 11
Description
I was playing with some parallel programming but I’m so bad with it that I lost the flag, can you help me recover it?
2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf
Writeup
We are given an ELF binary that encrypts argv[1] with a custom algorithm.
The algorithm has two rounds:
1. It spawns 14 threads that execute 14 different functions. Since there is no synchronization (such as a barrier) between them, the functions run in a random order.
2. It spawns 14 threads again, but this time it ensures that the first 7 functions run in the same order as in round 1. The remaining 7 run in the reverse of the order in which the last 7 functions ran in round 1.
Each round starts from a fresh copy of argv[1] and encrypts it independently. The ciphertext we were given is therefore the concatenation of the two 16-byte outputs:
E1(flag) = 2aad2e5a49fb2d9adb908dd00eb48c8a
E2(flag) = 6607ab619f75b0272f3c1eb33fe9edaf
To check that the functions really run in this order in the two rounds, I wrote a gdb script:
import gdb
gdb.execute("set pagination off")
# breakpoints on decompiled source lines
gdb.execute("b 760")
gdb.execute("b 744")
gdb.execute("r abcdefghijklmnop")
print("\n\nFIRST LOOP\n\n")
for i in range(14):
    # at these breakpoints $rdx holds the address of the function about to run;
    # x/i prints it together with its symbol name
    gdb.execute("x/i $rdx")
    print(f"{i=}")
    gdb.execute("continue")
print("\n\nSECOND LOOP\n\n")
for i in range(14):
    gdb.execute("x/i $rdx")
    print(f"{i=}")
    gdb.execute("continue")
The breakpoints are set on source lines. This works because I used ghidra2dwarf to export the decompiled code as DWARF debug information inside the binary.
From gdb:
(gdb) source gdbscript.py
Then grep the output for func[0-9] to get:
first round
0x5555555554e4 <func1>: endbr64
0x555555555607 <func2>: endbr64
0x555555555678 <func3>: endbr64
0x5555555556ff <func4>: endbr64
0x55555555577d <func5>: endbr64
0x555555555821 <func6>: endbr64
0x5555555558f9 <func8>: endbr64
0x555555555974 <func9>: endbr64
0x5555555559db <func10>: endbr64
0x555555555a4e <func11>: endbr64
0x555555555c8c <func13>: endbr64
0x555555555b75 <func12>: endbr64
0x555555555bc5 <func14>: endbr64
0x555555555868 <func7>: endbr64
second round
0x5555555554e4 <func1>: endbr64
0x555555555607 <func2>: endbr64
0x555555555678 <func3>: endbr64
0x5555555556ff <func4>: endbr64
0x55555555577d <func5>: endbr64
0x555555555821 <func6>: endbr64
0x5555555558f9 <func8>: endbr64
0x555555555868 <func7>: endbr64 <--- from here on: reverse of round 1's last 7
0x555555555bc5 <func14>: endbr64
0x555555555b75 <func12>: endbr64
0x555555555c8c <func13>: endbr64
0x555555555a4e <func11>: endbr64
0x5555555559db <func10>: endbr64
0x555555555974 <func9>: endbr64
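The relation between the two listings can be checked mechanically. Here is a minimal sketch. The function numbers are transcribed by hand from the gdb output above, and the names round1 and round2 are ours:

```python
# Function numbers in the order gdb reported them (transcribed from the listings above)
round1 = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 12, 14, 7]
round2 = [1, 2, 3, 4, 5, 6, 8, 7, 14, 12, 13, 11, 10, 9]

# Round 2 keeps round 1's first 7 functions and reverses the order of the last 7
expected = round1[:7] + round1[7:][::-1]
print(round2 == expected)  # True
```

This only confirms the pattern for this particular run. The order of the first 7 functions changes from run to run.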
Strategy
Brute-forcing the correct permutation of all 14 functions would mean trying 14! = 87178291200 orderings, so we tried a meet-in-the-middle approach instead:
1. Write the inverse of each function.
2. Try every ordered choice of 7 of the 14 inverse functions and apply it to ciphertext1. There are 14 * 13 * 12 * 11 * 10 * 9 * 8 = 17297280 such choices.
3. For each choice from step 2, apply the same inverse functions in reverse order to ciphertext2.
4. The first 7 functions are the same, in the same order, in both rounds. So undoing the correct last 7 from both ciphertexts must give the same intermediate message, and a match identifies the last 7 functions of round 1.
5. Once the last 7 functions are known, brute-force the order of the remaining 7. There are only 7! = 5040 possible permutations.
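The same reasoning can be written as equations. This is a sketch in our own notation, not taken from the binary. Write f_1, …, f_14 for the fourteen functions, σ for the order in which round 1 happens to run them, m for the flag, and c_1, c_2 for ciphertext1 and ciphertext2. Composition is applied right to left, so F is the combined effect of the seven functions both rounds share:

```latex
\begin{aligned}
F &= f_{\sigma(7)} \circ \cdots \circ f_{\sigma(1)} \\
c_1 &= \left(f_{\sigma(14)} \circ \cdots \circ f_{\sigma(8)} \circ F\right)(m) \\
c_2 &= \left(f_{\sigma(8)} \circ \cdots \circ f_{\sigma(14)} \circ F\right)(m) \\
F(m) &= \left(f_{\sigma(8)}^{-1} \circ \cdots \circ f_{\sigma(14)}^{-1}\right)(c_1)
      = \left(f_{\sigma(14)}^{-1} \circ \cdots \circ f_{\sigma(8)}^{-1}\right)(c_2)
\end{aligned}
```

In the search of the second step, a candidate p = (p_1, …, p_7) stands for (σ(14), …, σ(8)). Its inverses are applied to c_1 in the order p_1, …, p_7 and to c_2 in the opposite order, which are exactly the two sides of the last equation.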
First step
To help write the inverse of each function, I scripted a radare2 wrapper (using r2pipe) that emulates a single function with ESIL.
#!/usr/bin/env python3
import r2pipe
import time
import itertools as it
# address at which the emulation of func<n> stops (index 0 is a placeholder)
funcs_end = [
    "0x0",
    "0x000015f0",  # func1
    "0x00001675",
    "0x000016fc",
    "0x0000177a",
    "0x0000180a",
    "0x00001865",
    "0x000018f6",
    "0x00001971",
    "0x000019d8",
    "0x00001a4b",
    "0x00001b72",
    "0x00001bc2",
    "0x00001ceb",
    "0x00001c89"  # func14
]
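def xor(a, b):
    return bytes(x^y for x,y in zip(a,b))
def write_arg(s):
    # write the 16-byte argument into the buffer mapped at 0x100
    r2.cmd(f"wx {s.hex()} @ 0x100")
    # alternative way if "o" is used instead of "om"
    # with open("./plaintext", "wb") as f:
    #     f.write(s)
def get_arg(log=False):
    # skip the hexdump header line and keep the row for 0x100
    out = r2.cmd("px/16xb @ 0x100").split("\n")[1]
    # the row is "offset  hex bytes  ascii": take the hex column, drop the spaces
    ret = out.split("  ")[1].replace(" ", "")
    if log:
        print(out, end="\n\n")
    return ret
def exec_func(i, arg, log=True):
    write_arg(arg)
    if log:
        print(f"BEFORE func{i}")
    get_arg(log=log)
    # seek to func<i>, init the ESIL VM, its memory and rip, then set rdi = buffer, rsi = length
    r2.cmd(f"s sym.func{i}; aei; aeim; aeip; aer rdi=0x100; aer rsi=0x10")
    # emulate until the end address of func<i>
    r2.cmd(f"aecu {funcs_end[int(i)]}")
    if log:
        print(f"AFTER func{i}")
    ret = get_arg(log=log)
    return ret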
DEFAULT_ARG = b"ptm{testabcdefg}"
r2 = r2pipe.open("./parallel_bku", flags=["-e", "bin.cache=true"])
r2.cmd("aaa; o+ ./plaintext 0x100")
# alternative way
# r2.cmd("aaa; om `om~[2:0]` 0x100 0x10 0x100 rw test")
print(exec_func(1, DEFAULT_ARG, log=True))
print(exec_func(4, DEFAULT_ARG, log=True))
print(exec_func(9, DEFAULT_ARG, log=True))
Every function is named func<n>, where <n> is its number. Each is called as func<n>(s, 0x10), where s is a copy of argv[1] and 0x10 is the length of s.
So I mapped a file (plaintext) read-write at address 0x100 with o+ ./plaintext 0x100 (an alternative is to use om).
Then, to execute (emulate) a function:
s sym.func<n> # seek to function
aei # init esil state
aeim # initialize esil memory
aeip # set rip = current seek
aer rdi=0x100 # set rdi
aer rsi=0x10 # set rsi
aecu <address> # continue execution until <address>
In the latest radare2 version it is also possible to set up the stack memory with argv and envp: https://twitter.com/trufae/status/1392886678180728832
Output example:
BEFORE func1
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}
AFTER func1
0x00000100  747b 6d70 6574 7374 6264 6361 667d 6765  t{mpetstbdcaf}ge
747b6d706574737462646361667d6765
BEFORE func4
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}
AFTER func4
0x00000100  7074 6d7b 7465 7374 1511 0610 1e0b 130d  ptm{test........
70746d7b74657374151106101e0b130d
BEFORE func9
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}
AFTER func9
0x00000100  9a9e 97a5 9e8f 9d9e 8b8c 8d8e 8f90 91a7  ................
9a9e97a59e8f9d9e8b8c8d8e8f9091a7
Then we wrote the inverse of every function:
def func1_inv(s):
    # undo the byte shuffle inside each 4-byte block
    for i in range(0, len(s), 4):
        s[i+1], s[i+3], s[i+2], s[i] = s[i], s[i+1], s[i+2], s[i+3]
def func2_inv(s):
    # xor the first half with the mirrored second half
    for i in range(8):
        s[i] ^= s[16 - i - 1]
def func3_inv(s):
    # rotate byte k (and byte k + 8) right by k bits
    for i in range(8):
        s[i] = (((s[i] >> i) & 0xff) | (s[i] << (8 - i)) & 0xff)
    for i in range(8):
        s[i + 8] = (((s[i + 8] >> i) & 0xff) | (s[i + 8] << (8 - i)) & 0xff)
def func4_inv(s):
    # xor the second half with the mirrored first half
    for i in range(8):
        s[16 - i - 1] ^= s[i]
def func5_inv(s):
    # xor with a fixed 16-byte key
    to_xor = b"{reverse_fake_flag}ptm"[:16]
    for i in range(16):
        s[i] ^= to_xor[i]
def func6_inv(s):
    # bitwise NOT
    for i in range(16):
        s[i] = (~s[i] & 0xff)
def func7_inv(s):
    # reverse the bit order of each byte
    for i in range(16):
        s[i] = int(bin(s[i])[2:].zfill(8)[::-1], 2)
def func8_inv(s):
    # reverse the byte order
    for i in range(8):
        s[16 - i - 1], s[i] = s[i], s[16 - i - 1]
def func9_inv(s):
    # subtract 42 modulo 256
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256
def func10_inv(s):
    # swap the two nibbles of each byte
    for i in range(16):
        s[i] = (s[i] << 4) & 0xff | (s[i] >> 4) & 0xf
# We then memoized this function
def func11_inv(s):
    # reorder the bits (MSB first) as b7 b3 b6 b2 b5 b1 b4 b0
    for i in range(16):
        cur = bin(s[i])[2:].zfill(8)
        s[i] = int(cur[0] + cur[4] + cur[1] + cur[5] + cur[2] + cur[6] + cur[3] + cur[7], 2)
def func12_inv(s):
    # bitwise NOT, then xor with the byte index
    for i in range(16):
        s[i] = (~s[i] & 0xff) ^ i
def func13_inv(s):
    # subtract the byte index modulo 256
    for i in range(16):
        s[i] = (s[i] + 256 - i) % 256
def func14_inv(s):
    # swap the case of ASCII letters
    for i in range(16):
        if ord('z') >= s[i] >= ord('a'):
            s[i] -= 32
        elif ord('Z') >= s[i] >= ord('A'):
            s[i] += 32
funcs_inv = [
    lambda x: x,  # placeholder so that funcs_inv[n] undoes func<n>
    func1_inv,
    func2_inv,
    func3_inv,
    func4_inv,
    func5_inv,
    func6_inv,
    func7_inv,
    func8_inv,
    func9_inv,
    func10_inv,
    func11_inv,
    func12_inv,
    func13_inv,
    func14_inv
]
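The comment above func11_inv says we memoized it. One way to do that is sketched below. The helper name func11_inv_byte is ours. The idea is to run the bit-string version once over all 256 byte values and keep the results as a lookup table. Its first entries match the d11_inv list used in the exploit:

```python
def func11_inv_byte(b):
    # Same bit shuffle as func11_inv, for a single byte:
    # new bits (MSB first) are b7 b3 b6 b2 b5 b1 b4 b0
    cur = bin(b)[2:].zfill(8)
    return int(cur[0] + cur[4] + cur[1] + cur[5] + cur[2] + cur[6] + cur[3] + cur[7], 2)

# Precompute the result for every possible byte value once
d11_inv = [func11_inv_byte(b) for b in range(256)]

def func11_inv(s):
    # One table lookup per byte instead of string manipulation
    for i in range(16):
        s[i] = d11_inv[s[i]]

print(d11_inv[:16])
# [0, 1, 4, 5, 16, 17, 20, 21, 64, 65, 68, 69, 80, 81, 84, 85]
```

This matters in the second step, where the inverses are called tens of millions of times.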
And a test that checks the inverses against the emulated original functions:
# needs exec_func (the radare2 wrapper) and funcs_inv from above
for f in range(1, 15):
    print(f"TESTING func{f}")
    for i in range(256):
        # message made of 16 copies of byte i
        msg = bytearray([i for _ in range(16)])
        ret = exec_func(f, log=False, arg=msg)
        x = bytearray(bytes.fromhex(ret))
        funcs_inv[f](x)
        assert x == msg
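One caveat: every test message above consists of 16 identical bytes. Such a message is left unchanged by any byte shuffle, so the test cannot detect a wrong byte order in func1_inv or func8_inv. A cheap extra check that needs no radare2 is to replay the BEFORE/AFTER pairs from the emulation output above, which use distinct bytes. This is a sketch: the hex strings are transcribed from that output, and the names after and inverses are ours.

```python
# Inverses copied from above for the three functions shown in the emulation output
def func1_inv(s):
    for i in range(0, len(s), 4):
        s[i+1], s[i+3], s[i+2], s[i] = s[i], s[i+1], s[i+2], s[i+3]

def func4_inv(s):
    for i in range(8):
        s[16 - i - 1] ^= s[i]

def func9_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256

# AFTER hexdumps from the emulation output; the BEFORE value was always b"ptm{testabcdefg}"
after = {
    1: "747b6d706574737462646361667d6765",
    4: "70746d7b74657374151106101e0b130d",
    9: "9a9e97a59e8f9d9e8b8c8d8e8f9091a7",
}
inverses = {1: func1_inv, 4: func4_inv, 9: func9_inv}

for n, hexdump in after.items():
    s = bytearray.fromhex(hexdump)
    inverses[n](s)
    print(n, bytes(s))  # each line ends with b'ptm{testabcdefg}'
```

The same idea (emulate once on a message with distinct bytes, then invert) would cover the other byte-shuffling functions as well.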
Second step
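#!/usr/bin/env python3
import itertools as it
# funcs_inv (the inverses from the first step) must be defined in this file
ciphertext = bytes.fromhex("2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf")
ciphertext1 = ciphertext[:16]
ciphertext2 = ciphertext[16:]
counter = 0
# every ordered choice of 7 of the 14 inverses (14!/7! = 17297280 candidates)
for p in it.permutations([x for x in range(1, 15)], 7):
    if counter % 100000 == 0:
        print(counter)
    # undo the candidate last 7 functions of round 1
    s = bytearray(ciphertext1)
    for i in p:
        funcs_inv[i](s)
    # round 2 ran them in reverse order, so undo them in reverse order
    s1 = bytearray(ciphertext2)
    for i in p[::-1]:
        funcs_inv[i](s1)
    # both must reach the state left by the shared first 7 functions
    if s1 == s:
        print(p)
        exit(0)
    counter += 1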
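With pypy3 it took around two minutes and found the permutation (8, 9, 14, 4, 10, 12, 7). These are the inverses to apply to ciphertext1, in that order, so the last seven functions of round 1 ran in the order 7, 12, 10, 4, 14, 9, 8.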
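As a sanity check that skips the 17-million-candidate search, the reported permutation can be tested directly. Only the seven inverses it uses are needed (copied from the first step), and both ciphertext halves should collapse to the same intermediate state. A minimal sketch; the dict name inv is ours:

```python
ciphertext = bytes.fromhex("2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf")
ciphertext1, ciphertext2 = ciphertext[:16], ciphertext[16:]

# The seven inverses used by the found permutation (copied from the first step)
def func4_inv(s):
    for i in range(8):
        s[16 - i - 1] ^= s[i]

def func7_inv(s):
    for i in range(16):
        s[i] = int(bin(s[i])[2:].zfill(8)[::-1], 2)

def func8_inv(s):
    for i in range(8):
        s[16 - i - 1], s[i] = s[i], s[16 - i - 1]

def func9_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256

def func10_inv(s):
    for i in range(16):
        s[i] = (s[i] << 4) & 0xff | (s[i] >> 4) & 0xf

def func12_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff) ^ i

def func14_inv(s):
    for i in range(16):
        if ord('z') >= s[i] >= ord('a'):
            s[i] -= 32
        elif ord('Z') >= s[i] >= ord('A'):
            s[i] += 32

inv = {4: func4_inv, 7: func7_inv, 8: func8_inv, 9: func9_inv,
       10: func10_inv, 12: func12_inv, 14: func14_inv}
p = (8, 9, 14, 4, 10, 12, 7)

s = bytearray(ciphertext1)
for i in p:          # undo round 1's last seven functions, most recent first
    inv[i](s)
s1 = bytearray(ciphertext2)
for i in p[::-1]:    # round 2 ran them in reverse, so undo them in reverse
    inv[i](s1)

print(s == s1, s.hex())  # True 9f5baa4d8973b9c797453bf67d58b76f
```

The shared state is what the first seven functions of round 1 turn the flag into. It is the starting point of the third step.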
Third step
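# undo the last 7 functions of round 1 found in the second step
c = bytearray(ciphertext1)
for i in [8, 9, 14, 4, 10, 12, 7]:
    funcs_inv[i](c)
# brute-force the order of the remaining 7 (7! = 5040 permutations)
possibilities = [1, 2, 3, 5, 6, 11, 13]
for p in it.permutations(possibilities):
    enc = c[:]
    for i in p:
        funcs_inv[i](enc)
    if enc.startswith(b"ptm"):
        print(enc)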
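The startswith(b"ptm") test can let false positives through. Below is a slightly stricter filter, written as a hypothetical helper (looks_like_flag is our name). It assumes the flag is printable ASCII of the form ptm{...}, as the final flag turns out to be:

```python
def looks_like_flag(candidate):
    # Assumed flag format: printable ASCII wrapped in ptm{...}
    # (works for both bytes and bytearray)
    return (
        candidate.startswith(b"ptm{")
        and candidate.endswith(b"}")
        and all(0x20 <= b < 0x7f for b in candidate)
    )

print(looks_like_flag(b"ptm{brut3_f0rc3}"))  # True
print(looks_like_flag(b"ptm\x01\x02\x03"))    # False
print(looks_like_flag(b"ptm{\xff}"))          # False
```

In the loop above, it would replace the enc.startswith(b"ptm") check.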
Exploit
#!/usr/bin/env python3
import itertools as it
def xor(a, b):
    return bytes(x^y for x,y in zip(a,b))
def func1_inv(s):
    for i in range(0, len(s), 4):
        s[i+1], s[i+3], s[i+2], s[i] = s[i], s[i+1], s[i+2], s[i+3]
def func2_inv(s):
    for i in range(8):
        s[i] ^= s[16 - i - 1]
def func3_inv(s):
    for i in range(8):
        s[i] = (((s[i] >> i) & 0xff) | (s[i] << (8 - i)) & 0xff)
    for i in range(8):
        s[i + 8] = (((s[i + 8] >> i) & 0xff) | (s[i + 8] << (8 - i)) & 0xff)
def func4_inv(s):
    for i in range(8):
        s[16 - i - 1] ^= s[i]
def func5_inv(s):
    to_xor = b"{reverse_fake_flag}ptm"[:16]
    for i in range(16):
        s[i] ^= to_xor[i]
def func6_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff)
def func7_inv(s):
    for i in range(16):
        s[i] = int(bin(s[i])[2:].zfill(8)[::-1], 2)
def func8_inv(s):
    for i in range(8):
        s[16 - i - 1], s[i] = s[i], s[16 - i - 1]
def func9_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256
def func10_inv(s):
    for i in range(16):
        s[i] = (s[i] << 4) & 0xff | (s[i] >> 4) & 0xf
d11_inv = [0, 1, 4, 5, 16, 17, 20, 21, 64, 65, 68, 69, 80, 81, 84, 85, 2, 3, 6, 7, 18, 19, 22, 23, 66, 67, 70, 71, 82, 83, 86, 87, 8, 9, 12, 13, 24, 25, 28, 29, 72, 73, 76, 77, 88, 89, 92, 93, 10, 11, 14, 15, 26, 27, 30, 31, 74, 75, 78, 79, 90, 91, 94, 95, 32, 33, 36, 37, 48, 49, 52, 53, 96, 97, 100, 101, 112, 113, 116, 117, 34, 35, 38, 39, 50, 51, 54, 55, 98, 99, 102, 103, 114, 115, 118, 119, 40, 41, 44, 45, 56, 57, 60, 61, 104, 105, 108, 109, 120, 121, 124, 125, 42, 43, 46, 47, 58, 59, 62, 63, 106, 107, 110, 111, 122, 123, 126, 127, 128, 129, 132, 133, 144, 145, 148, 149, 192, 193, 196, 197, 208, 209, 212, 213, 130, 131, 134, 135, 146, 147, 150, 151, 194, 195, 198, 199, 210, 211, 214, 215, 136, 137, 140, 141, 152, 153, 156, 157, 200, 201, 204, 205, 216, 217, 220, 221, 138, 139, 142, 143, 154, 155, 158, 159, 202, 203, 206, 207, 218, 219, 222, 223, 160, 161, 164, 165, 176, 177, 180, 181, 224, 225, 228, 229, 240, 241, 244, 245, 162, 163, 166, 167, 178, 179, 182, 183, 226, 227, 230, 231, 242, 243, 246, 247, 168, 169, 172, 173, 184, 185, 188, 189, 232, 233, 236, 237, 248, 249, 252, 253, 170, 171, 174, 175, 186, 187, 190, 191, 234, 235, 238, 239, 250, 251, 254, 255]
def func11_inv(s):
    for i in range(16):
        s[i] = d11_inv[s[i]]
def func12_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff) ^ i
def func13_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - i) % 256
def func14_inv(s):
    for i in range(16):
        if ord('z') >= s[i] >= ord('a'):
            s[i] -= 32
        elif ord('Z') >= s[i] >= ord('A'):
            s[i] += 32
funcs_inv = [
    lambda x: x,
    func1_inv,
    func2_inv,
    func3_inv,
    func4_inv,
    func5_inv,
    func6_inv,
    func7_inv,
    func8_inv,
    func9_inv,
    func10_inv,
    func11_inv,
    func12_inv,
    func13_inv,
    func14_inv
]
ciphertext = bytes.fromhex("2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf")
ciphertext1 = ciphertext[:16]
ciphertext2 = ciphertext[16:]
# counter = 0
# for p in it.permutations([x for x in range(1, 15)], 7):
#     if counter % 100000 == 0:
#         print(counter)
#     s = bytearray(ciphertext1)
#     for i in p:
#         funcs_inv[i](s)
#     s1 = bytearray(ciphertext2)
#     for i in p[::-1]:
#         funcs_inv[i](s1)
#     if s1 == s:
#         print(p)
#         exit(0)
#     counter += 1
c = bytearray(ciphertext1)
for i in [8, 9, 14, 4, 10, 12, 7]:
    funcs_inv[i](c)
possibilities = [1, 2, 3, 5, 6, 11, 13] # (8, 9, 14, 4, 10, 12, 7)
for p in it.permutations(possibilities):
    enc = c[:]
    for i in p:
        funcs_inv[i](enc)
    if enc.startswith(b"ptm"):
        print(enc)
Flag
ptm{brut3_f0rc3}