Confusion
SrdnlenCTF 2025: Cryptography
#!/usr/bin/env python3
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
import os
# The real flag is read from the FLAG environment variable; a placeholder is used when it is unset.
FLAG = os.environ.get("FLAG", "srdnlen{REDACTED}").encode()
# Server encryption function
def encrypt(msg, key):
    """Encrypt ``msg`` with the challenge's custom AES-based construction.

    The message is PKCS#7-padded to 16-byte blocks P_1..P_n, and a random
    16-byte block P_0 is prepended. With E/D = AES-ECB encrypt/decrypt under
    ``key``, three passes are made for i >= 1:

        b_i  = E(P_i) xor P_{i-1}
        c_i  = D(b_i)
        ct_i = E(c_i) xor c_{i-1}

    with b_0 = c_0 = ct_0 = P_0. Returns ct_0 || ct_1 || ... || ct_n as bytes.
    """
    # ECB mode is stateless per block, so one cipher object can serve all calls.
    cipher = AES.new(key, AES.MODE_ECB)
    padded = pad(msg, 16)
    head = os.urandom(16)
    plain = [head] + [padded[off:off + 16] for off in range(0, len(padded), 16)]

    def xor(x, y):
        return bytes(p ^ q for p, q in zip(x, y))

    # Pass 1: encrypt each block and chain with the previous *plaintext* block.
    mixed = [head] + [xor(cipher.encrypt(cur), prev) for prev, cur in zip(plain, plain[1:])]
    # Pass 2: decrypt every chained block (the random head block is kept as-is).
    undone = [head] + [cipher.decrypt(blk) for blk in mixed[1:]]
    # Pass 3: encrypt again and chain with the previous pass-2 block.
    out = [head] + [xor(cipher.encrypt(cur), prev) for prev, cur in zip(undone, undone[1:])]
    return b"".join(out)
# A random 256-bit key is generated once per run; the flag is encrypted once and shown.
KEY = os.urandom(32)
print("Let's try to make it confusing")
flag = encrypt(FLAG, KEY).hex()
print(f"|\n| flag = {flag}")

# Encryption oracle: every request encrypts (user message || FLAG) under the same KEY.
while True:
    print("|\n| ~ Want to encrypt something?")
    user_hex = input("|\n| > (hex) ")
    msg = bytes.fromhex(user_hex)
    # NOTE: encrypt() pads again internally, so the input ends up padded twice.
    plaintext = pad(msg + FLAG, 16)
    ciphertext = encrypt(plaintext, KEY)
    print("|\n| ~ Here is your encryption:", f"|\n| {ciphertext.hex()}", sep="\n")
In this challenge we are given an encryption oracle for a custom scheme that uses AES as a subroutine; it encrypts our message with the flag appended.
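Notation: $E(\cdot)$ and $D(\cdot)$ denote AES-ECB encryption and decryption under the secret key, and $\oplus$ denotes XOR. The padded plaintext is split into 16-byte blocks $P_1, \dots, P_n$, preceded by a random block $P_0$. The ciphertext terms are numbered $ct_0, ct_1, \dots, ct_n$, so the "second term" is $ct_1$ and the "fourth term" is $ct_3$.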
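Analysing the `encrypt` function gives the following pattern, for $i \ge 1$:
$$b_i = E(P_i) \oplus P_{i-1}, \qquad c_i = D(b_i), \qquad ct_i = E(c_i) \oplus c_{i-1},$$
with $b_0 = c_0 = ct_0 = P_0$.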
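Note that here $P_0 = IV$ for some random 16-byte $IV$. Now writing everything in terms of the plaintext blocks $P_i$, we get:
$$ct = \Big(P_0,\; E(P_1),\; E(P_2) \oplus P_1 \oplus D\big(E(P_1) \oplus P_0\big),\; E(P_3) \oplus P_2 \oplus D\big(E(P_2) \oplus P_1\big),\; \dots\Big),$$
that is, $ct_1 = E(P_1)$ and $ct_i = E(P_i) \oplus P_{i-1} \oplus D\big(E(P_{i-1}) \oplus P_{i-2}\big)$ for $i \ge 2$.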
We are also given the ciphertext of the flag along with the encryption oracle. Write the flag's padded plaintext blocks as $F_1, F_2, \dots, F_m$ and let $A_k = E(F_k)$.

To decrypt the flag's ciphertext, we need to use the oracle to compute decryptions $D(\cdot)$ of chosen blocks. Looking at the final form of $ct$, the fourth term, $ct_3 = E(P_3) \oplus P_2 \oplus D\big(E(P_2) \oplus P_1\big)$, is the most helpful. Since we can encrypt arbitrarily long messages, we have complete control over $P_1$, $P_2$ and $P_3$. Also notice that $P_1$ occurs only once in this term, inside $D(\cdot)$. We can therefore choose it to make the argument of $D$ anything we like, without unwanted side effects.

This motivates the idea for extracting the flag: substitute $P_1 = E(0) \oplus A_k$, $P_2 = 0$, $P_3 = 0$, i.e. send $(E(0) \oplus A_k) \,\|\, 0^{16} \,\|\, 0^{16}$. This gives the output
$$ct_3 = E(0) \oplus D\big(E(0) \oplus E(0) \oplus A_k\big) = E(0) \oplus F_k.$$
Taking the XOR of $ct_3$ and $E(0)$ then gives us the flag block $F_k$.

We can find $E(0)$ by sending the message $0^{16}$ (a full block of zeroes) and taking the second term of the ciphertext, $ct_1 = E(0)$. Finding $A_k$, however, takes a little more work.
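For $k = 1$, just take the second term of the flag's ciphertext: $A_1 = E(F_1)$.

For $k = 2$, send $0^{16}$ (a block of zeroes). Then $P_1 = 0$, $P_2 = F_1$ and $P_3 = F_2$, and the fourth term collapses to
$$ct_3 = E(F_2) \oplus F_1 \oplus D\big(E(F_1)\big) = A_2,$$
since $D(E(F_1)) = F_1$ and $F_1 \oplus F_1 = 0$.

For each subsequent $k \ge 3$, with $F_{k-2}$ and $F_{k-1}$ already recovered, we use the following inductive procedure: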
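- Send $F_{k-2} \,\|\, F_{k-1}$. The block after the message is then $P_3 = F_1$, so extract $r = ct_3 = A_1 \oplus F_{k-1} \oplus D\big(E(F_{k-1}) \oplus F_{k-2}\big)$.
- Extract the $(k+1)$-th term of the flag's ciphertext: $s = ct_k = A_k \oplus F_{k-1} \oplus D\big(E(F_{k-1}) \oplus F_{k-2}\big)$.
- Calculate $A_k = r \oplus s \oplus A_1$, then recover $F_k$ with the substitution above.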
This sums up the protocol.
from pwn import *

# Solver for "Confusion": recovers the flag block by block using only the
# encryption oracle, following the inductive procedure derived above.
# context.log_level = 'debug'

HOST, PORT = 'confusion.challs.srdnlen.it', 1338
BLOCK = 16        # AES block size in bytes
HEX = 2 * BLOCK   # one block in hex characters


def ct_block(ct_hex, i):
    """Return block ``i`` (0 = the random first block) of a hex ciphertext as an int."""
    return int(ct_hex[i * HEX:(i + 1) * HEX], 16)


def block_hex(x):
    """Format int ``x`` as exactly one block of hex, keeping leading zeros.

    ``hex(x)[2:]`` drops leading zero nibbles. That misaligns (or makes
    odd-length) the payload whenever a random AES output starts with 0.
    """
    return f'{x:0{HEX}x}'


def query(io, msg_hex):
    """Send ``msg_hex`` (a hex string) to the oracle and return the ciphertext hex."""
    io.sendlineafter(b'> (hex) ', msg_hex.encode())
    io.recvuntil(b'Here is your encryption:')
    io.recvline()
    io.recvuntil(b'| ')
    ct = io.recvline().strip().decode()
    log.info(f'ct: {ct}')
    return ct


def strip_pkcs7(data):
    """Remove PKCS#7 padding from ``data`` if it is present and well-formed."""
    n = data[-1] if data else 0
    if 1 <= n <= BLOCK and data.endswith(bytes([n]) * n):
        return data[:-n]
    return data


def main():
    # For local testing: io = process('./chall.py')
    io = remote(HOST, PORT)

    io.recvuntil(b'flag = ')
    flag_ct = io.recvline().strip().decode()
    log.info(f'flag: {flag_ct}')
    n_blocks = len(flag_ct) // HEX - 1  # exclude the random first block

    # E(0) is the second term of the encryption of an all-zero message.
    e0 = ct_block(query(io, '00' * 2 * BLOCK), 1)

    def decrypt_block(a):
        """Return D(a): sending (E(0)^a) || 0 || 0 makes ct_3 = E(0) ^ D(a)."""
        ct = query(io, block_hex(a ^ e0) + '00' * 2 * BLOCK)
        return ct_block(ct, 3) ^ e0

    a = [None, ct_block(flag_ct, 1)]  # a[k] = E(F_k), 1-indexed
    f = [None]                        # f[k] = F_k, 1-indexed
    for k in range(1, n_blocks + 1):
        if k == 2:
            # Sending one zero block makes ct_3 collapse to E(F_2).
            a.append(ct_block(query(io, '00' * BLOCK), 3))
        elif k >= 3:
            # ct_3 of F_{k-2}||F_{k-1} and ct_k of the flag share the D(...) term.
            r = ct_block(query(io, block_hex(f[k - 2]) + block_hex(f[k - 1])), 3)
            a.append(r ^ ct_block(flag_ct, k) ^ a[1])
        f.append(decrypt_block(a[k]))
        log.info(f'f{k}: {f[k].to_bytes(BLOCK, "big")!r}')

    recovered = b''.join(x.to_bytes(BLOCK, 'big') for x in f[1:])
    print(f'flag: {strip_pkcs7(recovered).decode()}')


if __name__ == '__main__':
    main()