Reverse an encryption algorithm with the help of radare2

Information

  • category : reverse
  • points : 301
  • solves: 11

Description

I was playing with some parallel programming but I’m so bad with it that I lost the flag, can you help me recover it?

2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf

Writeup

We have an ELF binary that encrypts argv[1] with a custom algorithm.

The algorithm has two rounds:

  1. It spawns 14 threads that execute 14 different functions. Since there is no synchronization between them (no barrier), the functions run in an unpredictable order.
  2. It spawns 14 threads again, but this time it ensures that the first 7 functions run in the same order as in 1), and the other 7 run in the reverse of the order in which the last 7 functions ran in 1).

Each round takes a fresh copy of the plaintext from argv[1] and encrypts it independently, so the ciphertext we were given is the concatenation of the two 16-byte results:

E1(flag) = 2aad2e5a49fb2d9adb908dd00eb48c8a
E2(flag) = 6607ab619f75b0272f3c1eb33fe9edaf
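In formulas, write x for the flag, f_j for func<j>, and σ_1, …, σ_14 for the (unknown) order in which round 1 runs the functions. Assuming every function is invertible and the round-2 rule above holds for the run that produced the given ciphertext, the two halves and the shared intermediate message M are:

```latex
\begin{aligned}
E_1 &= f_{\sigma_{14}} \circ \cdots \circ f_{\sigma_8} \circ f_{\sigma_7} \circ \cdots \circ f_{\sigma_1}(x) \\
E_2 &= f_{\sigma_8} \circ \cdots \circ f_{\sigma_{14}} \circ f_{\sigma_7} \circ \cdots \circ f_{\sigma_1}(x) \\
M   &= f_{\sigma_7} \circ \cdots \circ f_{\sigma_1}(x)
     = f_{\sigma_8}^{-1} \circ \cdots \circ f_{\sigma_{14}}^{-1}(E_1)
     = f_{\sigma_{14}}^{-1} \circ \cdots \circ f_{\sigma_8}^{-1}(E_2)
\end{aligned}
```

So a guess for the last seven functions can be tested on its own. Undo them from E_1 starting with f_{σ_14}, undo them from E_2 starting with f_{σ_8}, and check that both give the same M.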

To check the order in which the functions are executed in the two rounds, I wrote a gdb script:

import gdb

gdb.execute("set pagination off")
gdb.execute("b 760")
gdb.execute("b 744")
gdb.execute("r abcdefghijklmnop")

# At every stop rdx points to the func<n> about to run in a thread,
# so "x/i $rdx" prints its symbol name
print("\n\nFIRST LOOP\n\n")
for i in range(14):
    gdb.execute("x/i $rdx")
    print(f"{i=}")
    gdb.execute("continue")

print("\n\nSECOND LOOP\n\n")
for i in range(14):
    gdb.execute("x/i $rdx")
    print(f"{i=}")
    gdb.execute("continue")

The breakpoints are set on source-code lines. This is possible because I used ghidra2dwarf to export DWARF debug information into the binary.

Run it from gdb:

(gdb) source gdbscript.py

Then grep the output for func[0-9] to get:

   first round
   0x5555555554e4 <func1>:      endbr64 
   0x555555555607 <func2>:      endbr64 
   0x555555555678 <func3>:      endbr64 
   0x5555555556ff <func4>:      endbr64 
   0x55555555577d <func5>:      endbr64 
   0x555555555821 <func6>:      endbr64 
   0x5555555558f9 <func8>:      endbr64 
   0x555555555974 <func9>:      endbr64 
   0x5555555559db <func10>:     endbr64 
   0x555555555a4e <func11>:     endbr64 
   0x555555555c8c <func13>:     endbr64 
   0x555555555b75 <func12>:     endbr64 
   0x555555555bc5 <func14>:     endbr64 
   0x555555555868 <func7>:      endbr64

   second round
   0x5555555554e4 <func1>:      endbr64 
   0x555555555607 <func2>:      endbr64 
   0x555555555678 <func3>:      endbr64 
   0x5555555556ff <func4>:      endbr64 
   0x55555555577d <func5>:      endbr64 
   0x555555555821 <func6>:      endbr64 
   0x5555555558f9 <func8>:      endbr64 
   0x555555555868 <func7>:      endbr64   <--- here
   0x555555555bc5 <func14>:     endbr64 
   0x555555555b75 <func12>:     endbr64 
   0x555555555c8c <func13>:     endbr64 
   0x555555555a4e <func11>:     endbr64 
   0x5555555559db <func10>:     endbr64 
   0x555555555974 <func9>:      endbr64
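For this run, the two listings follow the rule described at the beginning. The first seven entries match, and the last seven of the second round are the last seven of the first round reversed. Here is a small check of that, with the orders copied from the gdb output above (expected_round2 is just an illustrative helper, not something from the binary):

```python
# Execution orders recorded by the gdb script above (function numbers).
round1 = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 12, 14, 7]
round2 = [1, 2, 3, 4, 5, 6, 8, 7, 14, 12, 13, 11, 10, 9]


def expected_round2(order1, k=7):
    """Round 2 keeps the first k functions of round 1 and reverses the rest."""
    return order1[:k] + order1[k:][::-1]


print(expected_round2(round1) == round2)  # True
```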

Strategy

Brute-forcing every possible order of the 14 functions is too expensive (14! = 87178291200 orderings), so we used a sort of meet-in-the-middle approach.

  1. Write the inverse of each function.
  2. Try every ordered choice of 7 of the 14 inverse functions and apply them to ciphertext1. There are 14 * 13 * 12 * 11 * 10 * 9 * 8 = 17297280 such choices.
  3. For each choice from 2), apply the same inverses in reverse order to ciphertext2.
  4. The first 7 functions are executed in the same order in both rounds. So for the right choice, undoing the last 7 functions from ciphertext1 and from ciphertext2 gives the same intermediate message.
  5. Once the last 7 functions are known, brute-force the order of the remaining 7, which has only 7! = 5040 possibilities.
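Below is a scaled-down sketch of the search described above. It uses toy invertible functions instead of func1..func14, and 6 functions split 3 + 3 so that it runs instantly. The TOY table and the helpers run, encrypt and find_tails are all hypothetical and are not the challenge's code.

```python
import itertools as it
import random

# Hypothetical toy stand-ins for func<n> (NOT the challenge's functions).
# Each entry is (forward, inverse); both map 16 bytes to 16 bytes.
TOY = [
    (lambda s: bytes((b + 42) % 256 for b in s),
     lambda s: bytes((b - 42) % 256 for b in s)),
    (lambda s: bytes(b ^ i for i, b in enumerate(s)),
     lambda s: bytes(b ^ i for i, b in enumerate(s))),
    (lambda s: s[::-1], lambda s: s[::-1]),
    (lambda s: bytes(((b << 4) | (b >> 4)) & 0xFF for b in s),
     lambda s: bytes(((b << 4) | (b >> 4)) & 0xFF for b in s)),
    (lambda s: bytes((b + i) % 256 for i, b in enumerate(s)),
     lambda s: bytes((b - i) % 256 for i, b in enumerate(s))),
    (lambda s: s[1:] + s[:1], lambda s: s[-1:] + s[:-1]),
]


def run(order, s, inverse=False):
    """Apply the toy functions (or their inverses) in the given order."""
    for i in order:
        s = TOY[i][1 if inverse else 0](s)
    return s


def encrypt(pt, order1):
    """Model of the two rounds: round 2 reverses the second half of round 1."""
    k = len(order1) // 2
    order2 = order1[:k] + order1[k:][::-1]
    return run(order1, pt), run(order2, pt)


def find_tails(c1, c2, n):
    """Candidate tails p (last-executed function first) such that undoing p
    from c1 and p reversed from c2 reaches the same intermediate message."""
    return [p for p in it.permutations(range(n), n // 2)
            if run(p, c1, inverse=True) == run(p[::-1], c2, inverse=True)]


order1 = random.Random(1).sample(range(len(TOY)), len(TOY))
c1, c2 = encrypt(b"ptm{toy_example}", order1)
true_tail = tuple(order1[len(order1) // 2:][::-1])
print(true_tail in find_tails(c1, c2, len(TOY)))  # True
```

Unlike the script in the second step, which exits on the first match, find_tails collects every candidate. Functions that commute (here the two additive ones) can produce several equivalent tails, which is why a check on the recovered plaintext is still needed at the end.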

First step

To help write the inverse of each function, I scripted a radare2 wrapper that emulates a single function.

#!/usr/bin/env python3

import r2pipe

# Address where the emulation of func<n> stops (near the end of the function),
# indexed by n; entry 0 is a placeholder
funcs_end = [
    "0x0",
    "0x000015f0", # func1
    "0x00001675",
    "0x000016fc",
    "0x0000177a",
    "0x0000180a",
    "0x00001865",
    "0x000018f6",
    "0x00001971",
    "0x000019d8",
    "0x00001a4b",
    "0x00001b72",
    "0x00001bc2",
    "0x00001ceb",
    "0x00001c89" # func14
]

def xor(a, b):
    return bytes(x^y for x,y in zip(a,b))

def write_arg(s):
    r2.cmd(f"wx {s.hex()} @ 0x100")
    # alternative way if "o" is used instead of "om"
    # with open("./plaintext", "wb") as f:
    #     f.write(s)

def get_arg(log=False):
    out = r2.cmd("px/16xb @ 0x100").split("\n")[1]
    ret = out.split("  ")[1].replace(" ", "")
    if log:
        print(out, end="\n\n")
    return ret

def exec_func(i, arg, log=True):
    write_arg(arg)
    if log:
        print(f"BEFORE func{i}")
        get_arg(log=log)
    r2.cmd(f"s sym.func{i}; aei; aeim; aeip; aer rdi=0x100; aer rsi=0x10")
    r2.cmd(f"aecu {funcs_end[int(i)]}")
    if log:
        print(f"AFTER func{i}")
    ret = get_arg(log=log)
    return ret

DEFAULT_ARG = b"ptm{testabcdefg}"
r2 = r2pipe.open("./parallel_bku", flags=["-e", "bin.cache=true"])
r2.cmd("aaa; o+ ./plaintext 0x100")
# alternative way
# r2.cmd("aaa; om `om~[2:0]` 0x100 0x10 0x100 rw test") 

print(exec_func(1, DEFAULT_ARG, log=True))
print(exec_func(4, DEFAULT_ARG, log=True))
print(exec_func(9, DEFAULT_ARG, log=True))

Each function is named func<n>, where <n> is its number. It is called as func<n>(s, 0x10), where s is a copy of argv[1] and 0x10 is the length of s. Following the System V x86-64 calling convention, these two arguments are passed in rdi and rsi.

So I mapped a file (plaintext) at address 0x100 in read-write mode with o+ ./plaintext 0x100 (om is an alternative). Then, to emulate a function:

s sym.func<n>  # seek to function
aei            # init esil state
aeim           # initialize esil memory
aeip           # set rip = current seek
aer rdi=0x100  # set rdi
aer rsi=0x10   # set rsi
aecu <address> # continue execution until <address>

In the latest radare2 version it is also possible to initialize the stack memory with argv and envp: https://twitter.com/trufae/status/1392886678180728832

Output example:

BEFORE func1
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}

AFTER func1
0x00000100  747b 6d70 6574 7374 6264 6361 667d 6765  t{mpetstbdcaf}ge

747b6d706574737462646361667d6765

BEFORE func4
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}

AFTER func4
0x00000100  7074 6d7b 7465 7374 1511 0610 1e0b 130d  ptm{test........

70746d7b74657374151106101e0b130d

BEFORE func9
0x00000100  7074 6d7b 7465 7374 6162 6364 6566 677d  ptm{testabcdefg}

AFTER func9
0x00000100  9a9e 97a5 9e8f 9d9e 8b8c 8d8e 8f90 91a7  ................

9a9e97a59e8f9d9e8b8c8d8e8f9091a7
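These outputs are enough to guess the forward behaviour of the three functions:

  • func1 permutes each 4-byte block.
  • func4 xors the second half with the mirrored first half.
  • func9 adds 42 to every byte.

Below is a sketch of those guesses, checked against the hex dumps above. The names mirror the binary's symbols, but the bodies are reconstructions from the output, not decompiled code.

```python
def func1(s):
    # Within each 4-byte block the output is [s1, s3, s2, s0].
    return bytes(s[i + j] for i in range(0, len(s), 4) for j in (1, 3, 2, 0))


def func4(s):
    s = bytearray(s)
    for i in range(8):
        s[15 - i] ^= s[i]  # xor the second half with the mirrored first half
    return bytes(s)


def func9(s):
    return bytes((b + 42) % 256 for b in s)  # add 42 modulo 256


pt = b"ptm{testabcdefg}"
print(func1(pt).hex())  # 747b6d706574737462646361667d6765
print(func4(pt).hex())  # 70746d7b74657374151106101e0b130d
print(func9(pt).hex())  # 9a9e97a59e8f9d9e8b8c8d8e8f9091a7
```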

Then we wrote every inverse:

def func1_inv(s):
    for i in range(0, len(s), 4):
        s[i+1], s[i+3], s[i+2], s[i] = s[i], s[i+1], s[i+2], s[i+3]

def func2_inv(s):
    for i in range(8):
        s[i] ^= s[16 - i - 1]

def func3_inv(s):
    for i in range(8):
        s[i] = (((s[i] >> i) & 0xff) | (s[i] << (8 - i)) & 0xff)
    for i in range(8):
        s[i + 8] = (((s[i + 8] >> i) & 0xff) | (s[i + 8] << (8 - i)) & 0xff)

def func4_inv(s):
    for i in range(8):
        s[16 - i - 1] ^= s[i]

def func5_inv(s):
    to_xor = b"{reverse_fake_flag}ptm"[:16]
    for i in range(16):
        s[i] ^= to_xor[i]

def func6_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff)

def func7_inv(s):
    for i in range(16):
        s[i] = int(bin(s[i])[2:].zfill(8)[::-1], 2)

def func8_inv(s):
    for i in range(8):
        s[16 - i - 1], s[i] = s[i], s[16 - i - 1]

def func9_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256

def func10_inv(s):
    for i in range(16):
        s[i] = (s[i] << 4) & 0xff | (s[i] >> 4) & 0xf

# We later replaced this with a precomputed 256-entry lookup table (d11_inv in the exploit)
def func11_inv(s):
    for i in range(16):
        cur = bin(s[i])[2:].zfill(8)
        s[i] = int(cur[0] + cur[4] + cur[1] + cur[5] + cur[2] + cur[6] + cur[3] + cur[7], 2)

def func12_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff) ^ i

def func13_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - i) % 256

def func14_inv(s):
    for i in range(16):
        if ord('z') >= s[i] >= ord('a'):
            s[i] -= 32
        elif ord('Z') >= s[i] >= ord('A'):
            s[i] += 32

funcs_inv = [
    lambda x: x,
    func1_inv,
    func2_inv,
    func3_inv,
    func4_inv,
    func5_inv,
    func6_inv,
    func7_inv,
    func8_inv,
    func9_inv,
    func10_inv,
    func11_inv,
    func12_inv,
    func13_inv,
    func14_inv
]
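The comment above func11_inv says it was memoized, and the exploit below uses a pasted 256-entry list, d11_inv. Because the bit shuffle works on each byte independently, that table can also be generated from the per-byte logic. The sketch below does that. func11_inv_byte is a hypothetical helper name, and func11_inv is redefined here as the table-driven version.

```python
def func11_inv_byte(b):
    # Same bit shuffle as func11_inv, on one byte: MSB-first bits
    # b0..b7 become b0 b4 b1 b5 b2 b6 b3 b7.
    cur = format(b, "08b")
    return int(cur[0] + cur[4] + cur[1] + cur[5] + cur[2] + cur[6] + cur[3] + cur[7], 2)


# 256-entry lookup table, computed once instead of per call.
d11_inv = [func11_inv_byte(b) for b in range(256)]


def func11_inv(s):
    for i in range(16):
        s[i] = d11_inv[s[i]]


print(d11_inv[:8])  # [0, 1, 4, 5, 16, 17, 20, 21]
```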

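And a test to confirm their correctness. For every function and every byte value, it emulates the function on 16 copies of that byte and checks that the inverse restores the input: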

for f in range(1, 15):
    print(f"TESTING func{f}")
    for i in range(256):
        msg = bytearray([i for _ in range(16)])
        ret = exec_func(f, log=False, arg=msg)
        x = bytearray(bytes.fromhex(ret))
        funcs_inv[f](x)
        assert x == msg
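One caveat with the test above is that a buffer filled with a single repeated byte is left unchanged by pure position permutations such as func1 (block shuffle) and func8 (reversal). A wrong index in those inverses would therefore still pass. The sketch below is a stricter round-trip check that also feeds random buffers. check_inverse is a hypothetical helper, and forward stands for anything that runs the original function, for example a wrapper around exec_func.

```python
import os


def check_inverse(forward, inverse, trials=256):
    """Round-trip test: inverse(forward(x)) must give back x.

    forward: callable taking and returning 16 bytes.
    inverse: in-place function on a bytearray, like the funcN_inv above.
    Uses uniform buffers AND random ones, because a buffer of one repeated
    byte is unchanged by position permutations.
    """
    inputs = [bytes([i]) * 16 for i in range(256)]
    inputs += [os.urandom(16) for _ in range(trials)]
    for msg in inputs:
        x = bytearray(forward(msg))
        inverse(x)
        if x != bytearray(msg):
            return False
    return True
```

With the radare2 wrapper, this could be called as check_inverse(lambda m: bytes.fromhex(exec_func(f, m, log=False)), funcs_inv[f]) for each f from 1 to 14.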

Second step

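#!/usr/bin/env python3

import itertools as it

# funcs_inv: list of inverse functions from the first step

ciphertext = bytes.fromhex("2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf")
ciphertext1 = ciphertext[:16]
ciphertext2 = ciphertext[16:]
counter = 0
# p: candidate last 7 functions of round 1, listed from the last executed
for p in it.permutations(range(1, 15), 7):
    if counter % 100000 == 0:
        print(counter)
    # undo p from ciphertext1...
    s = bytearray(ciphertext1)
    for i in p:
        funcs_inv[i](s)
    # ...and p reversed from ciphertext2 (round 2 ran them in reverse order)
    s1 = bytearray(ciphertext2)
    for i in p[::-1]:
        funcs_inv[i](s1)
    if s1 == s:  # same intermediate message: last 7 functions found
        print(p)
        exit(0)
    counter += 1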

With pypy3 it took around two minutes and found the permutation (8, 9, 14, 4, 10, 12, 7).
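Since p lists the inverses in the order they are applied to ciphertext1, it reads the end of round 1 backwards. The sketch below turns it into execution orders and recovers the seven remaining functions, which match the possibilities list used in the third step. The variable names are illustrative only.

```python
p = (8, 9, 14, 4, 10, 12, 7)  # inverses applied to ciphertext1, in order

# p undoes round 1 from the end, so round 1 executed these seven last, in this order:
tail_round1 = list(p[::-1])
# round 2 ran the same seven in reverse order:
tail_round2 = list(p)
# the other seven ran first (in an order still unknown) in both rounds:
head = sorted(set(range(1, 15)) - set(p))
print(head)  # [1, 2, 3, 5, 6, 11, 13]
```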

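Third step

Undo those seven functions from ciphertext1, then try every order of the remaining seven inverses and keep the results that start with the flag prefix ptm: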

c = bytearray(ciphertext1)
for i in [8, 9, 14, 4, 10, 12, 7]:
    funcs_inv[i](c)

possibilities = [1, 2, 3, 5, 6, 11, 13]
for p in it.permutations(possibilities):
    enc = c[:]
    for i in p:
        funcs_inv[i](enc)
    if enc.startswith(b"ptm"):
        print(enc)
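With 7! = 5040 candidates, the ptm prefix alone could in principle let more than one candidate through. Below is a stricter filter, assuming the flag is 16 bytes, follows the ptm{...} format and is printable ASCII. looks_like_flag is a hypothetical helper, and it could replace the startswith test inside the loop.

```python
import string

# Printable ASCII without tabs and newlines.
PRINTABLE = set(string.printable.encode()) - set(b"\t\n\r\x0b\x0c")


def looks_like_flag(buf):
    """True if buf has the ptm{...} shape and only printable characters."""
    return (len(buf) == 16 and buf.startswith(b"ptm{") and buf.endswith(b"}")
            and all(b in PRINTABLE for b in buf))
```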

Exploit

#!/usr/bin/env python3

import itertools as it

def xor(a, b):
    return bytes(x^y for x,y in zip(a,b))

def func1_inv(s):
    for i in range(0, len(s), 4):
        s[i+1], s[i+3], s[i+2], s[i] = s[i], s[i+1], s[i+2], s[i+3]

def func2_inv(s):
    for i in range(8):
        s[i] ^= s[16 - i - 1]

def func3_inv(s):
    for i in range(8):
        s[i] = (((s[i] >> i) & 0xff) | (s[i] << (8 - i)) & 0xff)
    for i in range(8):
        s[i + 8] = (((s[i + 8] >> i) & 0xff) | (s[i + 8] << (8 - i)) & 0xff)

def func4_inv(s):
    for i in range(8):
        s[16 - i - 1] ^= s[i]

def func5_inv(s):
    to_xor = b"{reverse_fake_flag}ptm"[:16]
    for i in range(16):
        s[i] ^= to_xor[i]

def func6_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff)

def func7_inv(s):
    for i in range(16):
        s[i] = int(bin(s[i])[2:].zfill(8)[::-1], 2)

def func8_inv(s):
    for i in range(8):
        s[16 - i - 1], s[i] = s[i], s[16 - i - 1]

def func9_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - 42) % 256

def func10_inv(s):
    for i in range(16):
        s[i] = (s[i] << 4) & 0xff | (s[i] >> 4) & 0xf

d11_inv = [0, 1, 4, 5, 16, 17, 20, 21, 64, 65, 68, 69, 80, 81, 84, 85, 2, 3, 6, 7, 18, 19, 22, 23, 66, 67, 70, 71, 82, 83, 86, 87, 8, 9, 12, 13, 24, 25, 28, 29, 72, 73, 76, 77, 88, 89, 92, 93, 10, 11, 14, 15, 26, 27, 30, 31, 74, 75, 78, 79, 90, 91, 94, 95, 32, 33, 36, 37, 48, 49, 52, 53, 96, 97, 100, 101, 112, 113, 116, 117, 34, 35, 38, 39, 50, 51, 54, 55, 98, 99, 102, 103, 114, 115, 118, 119, 40, 41, 44, 45, 56, 57, 60, 61, 104, 105, 108, 109, 120, 121, 124, 125, 42, 43, 46, 47, 58, 59, 62, 63, 106, 107, 110, 111, 122, 123, 126, 127, 128, 129, 132, 133, 144, 145, 148, 149, 192, 193, 196, 197, 208, 209, 212, 213, 130, 131, 134, 135, 146, 147, 150, 151, 194, 195, 198, 199, 210, 211, 214, 215, 136, 137, 140, 141, 152, 153, 156, 157, 200, 201, 204, 205, 216, 217, 220, 221, 138, 139, 142, 143, 154, 155, 158, 159, 202, 203, 206, 207, 218, 219, 222, 223, 160, 161, 164, 165, 176, 177, 180, 181, 224, 225, 228, 229, 240, 241, 244, 245, 162, 163, 166, 167, 178, 179, 182, 183, 226, 227, 230, 231, 242, 243, 246, 247, 168, 169, 172, 173, 184, 185, 188, 189, 232, 233, 236, 237, 248, 249, 252, 253, 170, 171, 174, 175, 186, 187, 190, 191, 234, 235, 238, 239, 250, 251, 254, 255]

def func11_inv(s):
    for i in range(16):
        s[i] = d11_inv[s[i]]

def func12_inv(s):
    for i in range(16):
        s[i] = (~s[i] & 0xff) ^ i

def func13_inv(s):
    for i in range(16):
        s[i] = (s[i] + 256 - i) % 256

def func14_inv(s):
    for i in range(16):
        if ord('z') >= s[i] >= ord('a'):
            s[i] -= 32
        elif ord('Z') >= s[i] >= ord('A'):
            s[i] += 32

funcs_inv = [
    lambda x: x,
    func1_inv,
    func2_inv,
    func3_inv,
    func4_inv,
    func5_inv,
    func6_inv,
    func7_inv,
    func8_inv,
    func9_inv,
    func10_inv,
    func11_inv,
    func12_inv,
    func13_inv,
    func14_inv
]

ciphertext = bytes.fromhex("2aad2e5a49fb2d9adb908dd00eb48c8a6607ab619f75b0272f3c1eb33fe9edaf")
ciphertext1 = ciphertext[:16]
ciphertext2 = ciphertext[16:]
#  counter = 0
#  for p in it.permutations([x for x in range(1, 15)], 7):
#      if counter % 100000 == 0:
#          print(counter)
#      s = bytearray(ciphertext1)
#      for i in p:
#          funcs_inv[i](s)
#      s1 = bytearray(ciphertext2)
#      for i in p[::-1]:
#          funcs_inv[i](s1)
#      if s1 == s:
#          print(p)
#          exit(0)
#      counter += 1

c = bytearray(ciphertext1)
for i in [8, 9, 14, 4, 10, 12, 7]:
    funcs_inv[i](c)

possibilities = [1, 2, 3, 5, 6, 11, 13] # (8, 9, 14, 4, 10, 12, 7)
for p in it.permutations(possibilities):
    enc = c[:]
    for i in p:
        funcs_inv[i](enc)
    if enc.startswith(b"ptm"):
        print(enc)

Flag

ptm{brut3_f0rc3}