Google CTF 2021 Writeup

Mon Jul 19 2021

I participated in Google CTF 2021 as a member of WreckTheLine. We placed 12th out of 379 teams (counting only teams with a positive score). In this article I present writeups for the three crypto challenges I solved.

crypto

PYTHIA

65 solves

server.py
#!/usr/bin/python -u
import random
import string
import time

from base64 import b64encode, b64decode
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt

max_queries = 150  # total number of menu actions allowed per connection
query_delay = 10  # seconds slept before answering options 2 and 3

# Three independent 3-letter lowercase passwords (26**3 candidates each).
# NOTE(review): `random` is not a CSPRNG; `secrets.choice` is the usual choice for secrets.
passwords = [bytes(''.join(random.choice(string.ascii_lowercase) for _ in range(3)), 'UTF-8') for _ in range(3)]
# NOTE(review): the file handle is never explicitly closed.
flag = open("flag.txt", "rb").read()

def menu():
    """Print the main menu and read the selected option.

    Returns:
        The option entered by the user as an int, or -1 if reading or
        parsing it fails (the main loop then reports "Invalid option!").
    """
    print("What you wanna do?")
    print("1- Set key")
    print("2- Read flag")
    print("3- Decrypt text")
    print("4- Exit")
    try:
        return int(input(">>> "))
    except:
        # NOTE(review): a bare except also swallows EOFError and
        # KeyboardInterrupt; `except (ValueError, EOFError)` would be narrower.
        return -1

print("Welcome!\n")

# Index (0-2) of the password whose derived key option 3 decrypts with.
key_used = 0

# Every iteration consumes one of max_queries actions, whatever the option.
for query in range(max_queries):
    option = menu()

    if option == 1:
        # Select which of the three passwords option 3 will use.
        print("Which key you want to use [0-2]?")
        try:
            i = int(input(">>> "))
        except:
            i = -1
        if i >= 0 and i <= 2:
          key_used = i
        else:
          print("Please select a valid key.")
    elif option == 2:
        # The flag requires all three passwords concatenated in slot order.
        print("Password?")
        passwd = bytes(input(">>> "), 'UTF-8')

        print("Checking...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        if passwd == (passwords[0] + passwords[1] + passwords[2]):
            print("ACCESS GRANTED: " + flag.decode('UTF-8'))
        else:
            print("ACCESS DENIED!")
    elif option == 3:
        # Decryption oracle: only reports whether AES-GCM decryption (tag
        # check) under the selected key succeeded; the plaintext is never shown.
        print("Send your ciphertext ")

        ct = input(">>> ")
        print("Decrypting...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        try:
            nonce, ciphertext = ct.split(",")
            nonce = b64decode(nonce)
            ciphertext = b64decode(ciphertext)
        except:
            print("ERROR: Ciphertext has invalid format. Must be of the form \"nonce,ciphertext\", where nonce and ciphertext are base64 strings.")
            # `continue` also skips the "trials left" message below.
            continue

        # Cheap KDF (empty salt, n=2**4): all 26**3 candidate keys can be
        # recomputed offline, as solve.sage does.
        kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
        key = kdf.derive(passwords[key_used])
        try:
            cipher = AESGCM(key)
            # `ciphertext` must carry the tag in its last 16 bytes.
            plaintext = cipher.decrypt(nonce, ciphertext, associated_data=None)
        except:
            print("ERROR: Decryption failed. Key was not correct.")
            continue

        print("Decryption successful")
    elif option == 4:
        print("Bye!")
        break
    else:
        print("Invalid option!")
    # NOTE(review): after the first action this prints max_queries - 0 = 150,
    # i.e. it still counts the action just used — looks off by one.
    print("You have " + str(max_queries - query) + " trials left...\n")

This challenge is about AES's GCM mode.

Each of the three passwords has the form [a-z][a-z][a-z], so there are $26^3 = 17576$ candidates per password, and we have to recover all three to get the flag. The only thing we can do is submit a nonce and a ciphertext-tag pair. The server decrypts it with AES-GCM under the key derived by a KDF from the selected password, and tells us whether the decryption succeeded. With 150 queries in total, we can afford about 50 queries per password. To leak a password, we therefore need ciphertext-tag pairs that decrypt successfully under a chosen set of keys and fail under all other keys.

In GCM, the tag is computed in $GF(2^{128})$ as follows:

$$tag = C_0 H^{n+1} + C_1 H^{n} + \cdots + C_{n-1} H^{2} + L H + S$$

Here $H = Enc_K(0)$, $S = Enc_K(\mathrm{iv} \,\|\, 1)$, and $L = (0 \,\|\, \mathrm{bitlen}(C))$ is the length block. Associated data is omitted for simplicity, because associated_data is always None in this challenge.

$H$ and $S$ differ for each AES key; denote them $H_j, S_j$ for key $j$ ($0 \le j < m$). Decryption of a ciphertext-tag pair then succeeds under every key $j$ ($0 \le j < m$) if we choose $C_i$ and $tag$ such that

$$\begin{pmatrix} H_0^{n+1} & \cdots & H_0^2 & -1 \\ \vdots & \ddots & \vdots & \vdots \\ H_{m-1}^{n+1} & \cdots & H_{m-1}^2 & -1 \end{pmatrix} \begin{pmatrix} C_0 \\ \vdots \\ C_{n-1} \\ tag \end{pmatrix} = \begin{pmatrix} -L H_0 - S_0 \\ \vdots \\ -L H_{m-1} - S_{m-1} \end{pmatrix}$$

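This is a system of $m$ linear equations in the $n + 1$ unknowns $C_0, \ldots, C_{n-1}, tag$, so we can determine them uniquely when $n = m - 1$. In characteristic 2 we have $-1 = 1$, so the signs do not actually matter.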

In theory we could find the correct password by bisection alone, in about $\log_2 26^3 \simeq 14$ queries. However, let $M$ be the number of candidate keys covered by one ciphertext. If $M$ is too large (around 1,000), solving the linear system above takes too much time. So I first determined which of the intervals $[0, M), [M, 2M), \ldots, [\lfloor 26^3 / M \rfloor M, 26^3)$ contains the password. I then found the password within that interval by bisection.

solve.sage
import string
from base64 import b64encode
from itertools import product

from Crypto.Cipher import AES
from Crypto.Util.number import bytes_to_long, long_to_bytes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from pwn import remote


# Field used by GHASH: GF(2)[X] / (X^128 + X^7 + X^2 + X + 1).
X = GF(2).polynomial_ring().gen()
poly = sum(X ** e for e in (128, 7, 2, 1, 0))
F = GF(2 ** 128, name='a', modulus=poly)
R.<x> = PolynomialRing(F)  # NOTE(review): R and x appear unused in this script


def tobin(x, n):
    """Return the n-bit, most-significant-bit-first bit list of x.

    Raises AssertionError if x needs more than n bits.
    """
    value = Integer(x)
    width = value.nbits()
    assert width <= n
    msb_first = list(reversed(value.bits()))
    return [0] * (n - width) + msb_first


def frombin(v):
    """Interpret a sequence of bits (most significant first) as an integer."""
    digits = "".join(str(bit) for bit in v)
    return int(digits, 2)


def toF(x):
    """Map a 128-bit integer into F, reflecting its bits as GCM's bit order requires."""
    reflected = frombin(list(reversed(tobin(x, 128))))
    return F.fetch_int(reflected)


def fromF(x):
    """Inverse of toF: map an element of F back to a bit-reflected 128-bit integer."""
    as_int = x.integer_representation()
    return frombin(list(reversed(tobin(as_int, 128))))


# Fixed 96-bit nonce used for every query.
iv = b"\x01" * 12

# Per-candidate tables keyed by the 3-letter password.
pass_to_key = {}
pass_to_H = {}
pass_to_S = {}

# For every candidate, derive the AES key with the same Scrypt parameters as
# server.py (a fresh Scrypt object per password, as in the server), then
# precompute H = E_K(0^128) and S = E_K(iv || 0^31 || 1), both mapped into F.
j0_block = iv + b"\x00\x00\x00\x01"
for letters in product(string.ascii_lowercase, repeat=3):
    password = "".join(letters)
    derived = Scrypt(
        salt=b"", length=16, n=2 ** 4, r=8, p=1, backend=default_backend()
    ).derive(password.encode())
    pass_to_key[password] = derived

    ecb = AES.new(derived, mode=AES.MODE_ECB)
    pass_to_H[password] = toF(bytes_to_long(ecb.encrypt(b"\x00" * 16)))
    pass_to_S[password] = toF(bytes_to_long(ecb.encrypt(j0_block)))

# Fix one candidate ordering; indices into these lists identify candidates.
pass_list = list(pass_to_key)
H_list = [pass_to_H[pw] for pw in pass_list]
S_list = [pass_to_S[pw] for pw in pass_list]

N = 26 ** 3  # number of candidate passwords


def make_payload(start_idx, end_idx):
    """Build a "nonce,ciphertext" query that decrypts under every key in a range.

    For the M = end_idx - start_idx candidates pass_list[start_idx:end_idx],
    solve the M x M linear system over F for M - 1 ciphertext blocks
    C_0..C_{M-2} and a tag T such that, for every candidate j,

        C_0 H_j^M + ... + C_{M-2} H_j^2 + T = L H_j + S_j

    (addition is XOR in F, so "+" and "-" coincide). Decryption then succeeds
    under each of those M keys and, except with negligible probability, fails
    under every other key.

    Returns:
        "<base64 iv>,<base64 (ciphertext || tag)>".
    """
    M = end_idx - start_idx
    # GCM length block: 64-bit AAD length (0) || 64-bit ciphertext length in
    # bits. The AAD half is zero, so the block's integer value is just the
    # bit length of the M - 1 ciphertext blocks.
    L = toF(128 * (M - 1))
    C = matrix(F, M, M)
    T = vector(F, M)
    for i in range(M):
        H = H_list[start_idx + i]
        S = S_list[start_idx + i]
        # Column k (0 <= k <= M-2) holds H^(M-k); the last column is the tag's.
        power = H ** 2
        for j in range(M - 1):
            C[i, M - 2 - j] = power
            power *= H
        C[i, M - 1] = 1
        T[i] = L * H + S
    solution = C.solve_right(T)
    # One 16-byte big-endian block per field element; the final one is the tag.
    ct = b"".join(long_to_bytes(fromF(c)).rjust(16, b"\x00") for c in solution)
    return f"{b64encode(iv).decode()},{b64encode(ct).decode()}"


def find_key(M=400):
    """Recover the password of the currently selected key slot.

    First locate the batch of M candidates that contains it, then bisect
    inside that batch.
    """
    lo, hi = find_from_batch(M=M)
    return pass_list[find_from_bisect(lo, hi)]


def find_from_bisect(start_idx, end_idx):
    """Binary-search pass_list[start_idx:end_idx] for the server's key.

    Each step sends one payload valid for the lower half of the current range
    and keeps the half indicated by the server's answer. Returns the index of
    the single remaining candidate.
    """
    lo, hi = start_idx, end_idx
    while hi - lo != 1:
        mid = lo + (hi - lo) // 2
        payload = make_payload(lo, mid)
        _r.sendlineafter(">>> ", "3")
        _r.sendlineafter(">>> ", payload)
        _r.recvline()  # "Decrypting..."
        answer = _r.recvline().strip().decode()
        if "ERROR" in answer:
            lo = mid
        else:
            hi = mid
    return lo


def find_from_batch(M=400):
    """Locate the batch of candidates that contains the server's key.

    Queries the full batches [0, M), [M, 2M), ... with one payload each.
    Queries are pipelined: the payload for the next batch is built while the
    server sleeps on the previous one, and the previous answer is read
    afterwards.

    Returns:
        (start, end) such that pass_list[start:end] contains the key. If no
        full batch matches, the key must lie in the trailing partial batch
        [N // M * M, N), which is returned without being queried.

    Raises:
        RuntimeError: if N is a multiple of M and no batch matched (there is
            no untested remainder to fall back on).
    """
    full_end = N // M * M
    prev_start = None
    for start_idx in range(0, full_end, M):
        payload = make_payload(start_idx, start_idx + M)
        # Deal with the previous round only now, so that building this payload
        # overlaps with the server's query_delay.
        if prev_start is not None:
            ret = _r.recvline().strip().decode()
            if "ERROR" not in ret:
                return prev_start, prev_start + M
        _r.sendlineafter(">>> ", "3")
        _r.sendlineafter(">>> ", payload)
        _r.recvline()  # "Decrypting..."
        prev_start = start_idx
    # Answer to the last outstanding query, if any query was sent at all.
    if prev_start is not None:
        ret = _r.recvline().strip().decode()
        if "ERROR" not in ret:
            return prev_start, prev_start + M
    if full_end == N:
        raise RuntimeError("key not found in any batch")
    return full_end, N


recovered = []
_r = remote("pythia.2021.ctfcompetition.com", 1337)
for slot in range(3):
    # Select key slot `slot` on the server, then recover its password.
    _r.sendlineafter(">>> ", "1")
    _r.sendlineafter(">>> ", str(slot))
    recovered.append(find_key(M=400))
keys = "".join(recovered)
_r.sendlineafter(">>> ", "2")
_r.sendlineafter(">>> ", keys)
_r.recvline()  # "Checking..."
print(_r.recvline())
_r.close()

CTF{gCm_1s_n0t_v3ry_r0bust_4nd_1_sh0uld_us3_s0m3th1ng_els3_h3r3}

TIRAMISU

28 solves

This challenge is about ECDH. We are given the source code of a server that first performs a key exchange, then echoes back the messages we send, encrypted with the shared key.

The vulnerable point is as follows:

server.go
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// EstablishChannel derives the channel cipher/MAC keys from the client's
// ECDH public key and the server's private key.
//
// NOTE(review): the point is validated with IsOnCurve against peer.Curve,
// which proto2ecdsaKey takes from the client's message. ScalarMult below,
// however, uses server.key's curve and first reduces the coordinates modulo
// that curve's P. Nothing requires peer.Curve to match server.key's curve,
// so a point that passes the check on one curve can land on a different
// (weak) curve in the server's arithmetic — the invalid-curve weakness
// exploited in this challenge.
func (server *Server) EstablishChannel(clientHello *pb.ClientHello) error {
	// Load peer key.
	peer, err := proto2ecdsaKey(clientHello.Key)
	if err != nil {
		return err
	}

	// Key sanity checks.
	// NOTE(review): only checks the client-chosen curve, see above.
	if !peer.Curve.IsOnCurve(peer.X, peer.Y) {
		return fmt.Errorf("point (%X, %X) not on curve", peer.X, peer.Y)
	}

	// Compute shared secret.
	P := server.key.Params().P
	D := server.key.D.Bytes()
	// Only the x-coordinate of D * (X mod P, Y mod P) is kept.
	sharedX, _ := server.key.ScalarMult(new(big.Int).Mod(peer.X, P), new(big.Int).Mod(peer.Y, P), D)

	// Encode sharedX into a BitSize/8-byte buffer.
	masterSecret := make([]byte, server.key.Params().BitSize/8)
	sharedX.FillBytes(masterSecret)

	// Derive AES+MAC session keys.
	server.channel, err = newAuthCipher(masterSecret, channelCipherKdfInfo, channelMacKdfInfo)
	if err != nil {
		return fmt.Errorf("newAuthCipher()=%v, want nil error", err)
	}
	return nil
}
pb_util.go
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// proto2ecdsaKey converts a protobuf EcdhKey into an ecdsa.PublicKey.
//
// X and Y are parsed as unsigned big-endian integers and are not range- or
// curve-checked here. The curve (P-224 or P-256) is whatever the client put in
// key.Curve; any other id is rejected.
// NOTE(review): EstablishChannel validates the point only against this
// client-chosen curve, not against the server key's curve.
func proto2ecdsaKey(key *pb.EcdhKey) (*ecdsa.PublicKey, error) {
	out := &ecdsa.PublicKey{
		X: new(big.Int).SetBytes(key.Public.X),
		Y: new(big.Int).SetBytes(key.Public.Y),
	}
	switch key.Curve {
	case pb.EcdhKey_SECP224R1:
		out.Curve = elliptic.P224()
	case pb.EcdhKey_SECP256R1:
		out.Curve = elliptic.P256()
	default:
		return nil, fmt.Errorf("unsupported curve id %d", key.Curve)
	}
	return out, nil
}

The client can claim a different elliptic curve from the one the server uses to compute the shared key, so an invalid-curve attack is applicable.

First, I found points that lie on elliptic.P256, so they pass the IsOnCurve check for the curve we claim. After their coordinates are reduced modulo the P-224 prime, however, they have small prime order under the group arithmetic of elliptic.P224, which the server's ScalarMult uses. With enough such points, we can recover the secret key by the Chinese remainder theorem. The group law of an elliptic curve $y^2 = x^3 + ax + b$ does not depend on $b$. So, keeping P-224's field and $a$, we can brute-force $b$ until we find a curve whose order is divisible by a small prime. On such a curve, a point of that prime order is easy to obtain: take a random point and multiply it by the group order divided by the prime. Finally, I lifted the point's coordinates to integers that are congruent to it modulo the P-224 prime and that lie on P-256.

find_points.sage
import pickle

from collections import defaultdict

# NIST curve parameters. The original snippet relied on these being defined
# earlier (not shown); redefining them with the standard values is harmless.
p1 = 2 ** 224 - 2 ** 96 + 1  # P-224 field prime
a1 = -3
p2 = 2 ** 256 - 2 ** 224 + 2 ** 192 + 2 ** 96 - 1  # P-256 field prime
a2 = -3
b2 = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
EC2 = EllipticCurve(GF(p2), [a2, b2])

# prime -> b values for which y^2 = x^3 + a1*x + b over GF(p1) has a group
# order divisible by that prime.
prime_to_index = defaultdict(list)

for b in range(50):
    print(b)
    try:
        weak_curve = EllipticCurve(GF(p1), [a1, b])
    except ArithmeticError:
        # Singular curve for this b; skip it.
        continue
    for p, _ in factor(weak_curve.order()):
        prime_to_index[p].append(b)

primes = sorted(prime_to_index.keys())
prime_to_xy = {}

for prime in primes:
    print(prime)
    b = prime_to_index[prime][0]
    weak_curve = EllipticCurve(GF(p1), [a1, b])
    # Find a random point whose order is divisible by `prime` (any exponent),
    # then scale it down to a point of order exactly `prime`.
    while True:
        pt = weak_curve.random_point()
        pt_order = pt.order()
        if pt_order % prime == 0:
            break
    pt = pt * (pt_order // prime)
    assert pt.order() == prime
    x1, y1 = (int(c) for c in pt.xy())

    # Try x1, x1 + p1, x1 + 2*p1, ... (all congruent to x1 mod p1) until one
    # is an x-coordinate of a point on P-256.
    x2 = x1
    while True:
        try:
            y2 = int(EC2.lift_x(Integer(x2)).xy()[1])
            break
        except ValueError:
            x2 += p1

    # Add a multiple of p2 (keeps the point valid mod p2) so that y is also
    # congruent to y1 mod p1. Reducing the multiplier mod p1 keeps it
    # non-negative, so the payload y is a non-negative integer that
    # int.to_bytes in find_residual.py can serialize.
    k = (y1 - y2) * pow(p2, -1, p1) % p1
    y2_payload = y2 + k * p2
    prime_to_xy[int(prime)] = (int(x2), int(y2_payload))

with open("./prime_to_xy.pickle", "wb") as f:
    pickle.dump(prime_to_xy, f)

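Second, for each such prime $p$ I found $d \bmod p$, where $d$ is the server's secret key. To do so, I tried every candidate shared secret $x(iQ)$ for $i = 0, 1, \ldots, p - 1$ until the server accepted a message.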

find_residual.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import argparse
import pwnlib
import challenge_pb2
import struct
import sys

from Crypto.Util.number import long_to_bytes
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.asymmetric import ec

# HKDF "info" strings for the session channel keys.
CHANNEL_CIPHER_KDF_INFO = b"Channel Cipher v1.0"
CHANNEL_MAC_KDF_INFO = b"Channel MAC v1.0"

# HKDF "info" strings for the keys protecting the flag.
FLAG_CIPHER_KDF_INFO = b"Flag Cipher v1.0"
FLAG_MAC_KDF_INFO = b"Flag MAC v1.0"

# Fixed IV for every message this client encrypts.
IV = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff"

# prime -> (x, y) payload point, written by find_points.sage.
with open("./prime_to_xy.pickle", "rb") as fh:
    prime_to_xy = pickle.load(fh)

# P-224 parameters: y^2 = x^3 + a1*x + b1 over GF(p1).
a1 = -3
b1 = 0xB4050A850C04B3ABF54132565044B0B7D7BFD8BA270B39432355FFB4
p1 = 2 ** 224 - 2 ** 96 + 1


class AuthCipher(object):
    """Encrypt-then-MAC channel: AES-CTR plus HMAC-SHA256 over iv || ciphertext.

    Both 16-byte keys are derived from `secret` by HKDF-SHA256 with distinct
    info strings.
    """

    def __init__(self, secret, cipher_info, mac_info):
        self.cipher_key = self.derive_key(secret, cipher_info)
        self.mac_key = self.derive_key(secret, mac_info)

    def derive_key(self, secret, info):
        """Return a 16-byte HKDF-SHA256 key (no salt) for `secret` and `info`."""
        kdf = HKDF(algorithm=hashes.SHA256(), length=16, salt=None, info=info)
        return kdf.derive(secret)

    def encrypt(self, iv, plaintext):
        """Encrypt `plaintext` under AES-CTR with `iv`; return a Ciphertext message."""
        encryptor = Cipher(algorithms.AES(self.cipher_key), modes.CTR(iv)).encryptor()
        ct = encryptor.update(plaintext) + encryptor.finalize()

        tagger = hmac.HMAC(self.mac_key, hashes.SHA256())
        for chunk in (iv, ct):
            tagger.update(chunk)
        mac = tagger.finalize()

        out = challenge_pb2.Ciphertext()
        out.iv, out.data, out.mac = iv, ct, mac
        return out


def handle_pow(tube=None):
    """Solve the server's proof-of-work challenge (not implemented).

    `tube` defaults to None so that both `handle_pow()` and
    `handle_pow(tube)` call forms work.

    Raises:
        NotImplementedError: always.
    """
    # NotImplemented is a constant, not an exception class; calling it raised
    # a confusing TypeError instead of the intended error.
    raise NotImplementedError("proof-of-work solving is not implemented")


def read_message(tube, typ):
    """Read one length-prefixed protobuf message of type `typ` from `tube`.

    Framing: a 4-byte little-endian length followed by that many bytes.
    """
    (length,) = struct.unpack("<L", tube.recvnb(4))
    body = tube.recvnb(length)
    message = typ()
    message.ParseFromString(body)
    return message


def write_message(tube, msg):
    """Send `msg` to `tube` as a 4-byte little-endian length plus its serialized bytes."""
    body = msg.SerializeToString()
    header = struct.pack("<L", len(body))
    tube.send(header)
    tube.send(body)


def curve2proto(c):
    """Return the protobuf curve id for `c`; only secp224r1 is supported (asserted)."""
    assert c.name == "secp224r1"
    curve_ids = challenge_pb2.EcdhKey.CurveID
    return curve_ids.SECP224R1


def key2proto(key):
    """Convert a cryptography EC public key on secp224r1 into an EcdhKey message."""
    assert isinstance(key, ec.EllipticCurvePublicKey)
    msg = challenge_pb2.EcdhKey()
    msg.curve = curve2proto(key.curve)
    numbers = key.public_numbers()
    # Minimal-length big-endian encoding of each coordinate.
    for name, value in (("x", numbers.x), ("y", numbers.y)):
        setattr(msg.public, name, value.to_bytes((value.bit_length() + 7) // 8, "big"))
    return msg


def proto2key(key):
    """Convert an EcdhKey message on secp224r1 into a cryptography public key.

    Asserts that `key` is an EcdhKey message and that its curve is secp224r1.
    """
    assert isinstance(key, challenge_pb2.EcdhKey)
    assert key.curve == challenge_pb2.EcdhKey.CurveID.SECP224R1
    curve = ec.SECP224R1()
    x = int.from_bytes(key.public.x, "big")
    y = int.from_bytes(key.public.y, "big")
    # EllipticCurvePublicNumbers.encode_point() is deprecated (and removed in
    # newer cryptography releases); public_key() builds the same key directly.
    return ec.EllipticCurvePublicNumbers(x, y, curve).public_key()


class Point:
    """Affine point on y^2 = x^3 + A*x + B over GF(p).

    (0, 0) encodes the point at infinity. B is stored but never used by the
    group-law formulas, so the same arithmetic serves any curve with this A.
    """

    def __init__(self, coordinate, p, params):
        assert len(coordinate) == 2
        assert len(params) == 2
        x, y = coordinate
        self.x = x
        self.y = y
        self.p = p
        A, B = params
        self.A = A
        self.B = B

    def _infinity(self):
        """Return the point at infinity, encoded as (0, 0), with the same parameters."""
        return Point((0, 0), self.p, [self.A, self.B])

    def __add__(self, other):
        """Return self + other using the chord-and-tangent rule."""
        if other.x == 0 and other.y == 0:
            return self
        if self.x == 0 and self.y == 0:
            return other
        if self.x == other.x and (self.y + other.y) % self.p == 0:
            return self._infinity()
        if self.x == other.x:
            # Same point: tangent slope.
            _lambda = int(
                (3 * self.x ** 2 + self.A) * pow(2 * self.y, -1, self.p) % self.p
            )
        else:
            _lambda = int(
                (other.y - self.y) * pow(other.x - self.x, -1, self.p) % self.p
            )
        x3 = int((_lambda ** 2 - self.x - other.x) % self.p)
        y3 = int((_lambda * (self.x - x3) - self.y) % self.p)
        return Point((x3, y3), self.p, [self.A, self.B])

    def double(self):
        """Return 2 * self."""
        if (2 * self.y) % self.p == 0:
            return self._infinity()
        if self.x == 0 and self.y == 0:
            return self
        _lambda = int((3 * self.x ** 2 + self.A) * pow(2 * self.y, -1, self.p) % self.p)
        x3 = int((_lambda ** 2 - 2 * self.x) % self.p)
        y3 = int((_lambda * (self.x - x3) - self.y) % self.p)
        return Point((x3, y3), self.p, [self.A, self.B])

    def __neg__(self):
        """Return -self = (x, -y mod p); infinity maps to itself."""
        if self.x == 0 and self.y == 0:
            return self
        return Point((self.x, (-self.y) % self.p), self.p, [self.A, self.B])

    def __mul__(self, d):
        """Return d * self by iterative double-and-add.

        Negative d is computed as (-d) * (-self). The previous recursive form
        recursed without bound on negative scalars.
        """
        if d < 0:
            return (-self) * (-d)
        result = self._infinity()
        addend = self
        while d:
            if d & 1:
                result = result + addend
            d >>= 1
            if d:
                addend = addend.double()
        return result

    def __repr__(self):
        return f"({self.x}, {self.y})"


def _find_residue(tube, prime):
    """Do one handshake with the order-`prime` point and search for d mod prime.

    Returns the first i whose candidate secret x(i*Q) yields a reply from the
    server, or None if no i in range(prime) works.
    """
    tube.recvuntil("== proof-of-work: ")
    if tube.recvline().startswith(b"enabled"):
        handle_pow(tube)

    server_hello = read_message(tube, challenge_pb2.ServerHello)
    iv = server_hello.encrypted_flag.iv
    data = server_hello.encrypted_flag.data
    mac = server_hello.encrypted_flag.mac
    print(iv, data, mac)

    private_key = ec.generate_private_key(ec.SECP224R1())
    client_hello = challenge_pb2.ClientHello()
    payload = key2proto(private_key.public_key())
    x, y = prime_to_xy[prime]
    # Claim curve id 2 instead of key2proto's SECP224R1
    # (presumably SECP256R1 in challenge.proto — confirm).
    payload.curve = 2
    payload.public.x = x.to_bytes((x.bit_length() + 7) // 8, "big")
    payload.public.y = y.to_bytes((y.bit_length() + 7) // 8, "big")
    client_hello.key.CopyFrom(payload)

    write_message(tube, client_hello)

    point = Point([x % p1, y % p1], p1, [a1, b1])
    # Walk i*point incrementally (one addition per step) instead of
    # recomputing point * i; i = 0 is the point at infinity, encoded as (0, 0).
    tmp_point = point * 0
    for i in range(prime):
        shared_key = long_to_bytes(tmp_point.x).rjust(28, b"\x00")

        channel = AuthCipher(
            shared_key, CHANNEL_CIPHER_KDF_INFO, CHANNEL_MAC_KDF_INFO
        )
        msg = challenge_pb2.SessionMessage()
        msg.encrypted_data.CopyFrom(channel.encrypt(IV, b"hello"))
        write_message(tube, msg)

        reply = read_message(tube, challenge_pb2.SessionMessage)
        if len(reply.encrypted_data.iv) > 0:
            print("found!")
            print(prime, i)
            return i
        tmp_point = tmp_point + point
    return None


def run_session(host, port, _reversed):
    """Collect d mod p for the usable small primes, one connection per prime.

    Only primes with 1000 <= p <= 20000 from prime_to_xy are used. Results are
    kept in prime_to_res.pickle: existing entries are loaded first and skipped,
    so an interrupted run can be resumed, and the file is rewritten after
    every new residue.

    Args:
        host: challenge server host.
        port: challenge server port.
        _reversed: if true, process primes in descending order.
    """
    # prime_to_res was previously never defined, so the first membership test
    # raised NameError.
    try:
        with open("prime_to_res.pickle", "rb") as fh:
            prime_to_res = pickle.load(fh)
    except FileNotFoundError:
        prime_to_res = {}

    primes = sorted(prime_to_xy.keys(), reverse=bool(_reversed))
    for prime in primes:
        if prime in prime_to_res:
            continue
        if not 1000 <= prime <= 20000:
            continue
        print(prime)
        tube = pwnlib.tubes.remote.remote(host, port)
        try:
            residue = _find_residue(tube, prime)
        finally:
            tube.close()

        if residue is None:
            print("not found... bug!!!!")
            continue
        prime_to_res[prime] = residue
        # Save after every hit so that a crash does not lose earlier results.
        with open("prime_to_res.pickle", "wb") as fh:
            pickle.dump(prime_to_res, fh)

    with open("prime_to_res.pickle", "wb") as f:
        pickle.dump(prime_to_res, f)


def main():
    """Parse command-line options and collect residues from the server.

    Returns:
        0, used as the process exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host", metavar="H", type=str, default="127.0.0.1", help="challenge #host"
    )
    parser.add_argument(
        "--port", metavar="P", type=int, default=1337, help="challenge #port"
    )
    # The help text was copy-pasted from --port.
    parser.add_argument(
        "--reversed",
        action="store_true",
        help="process primes in descending order",
    )
    args = parser.parse_args()

    run_session(args.host, args.port, args.reversed)

    return 0


if __name__ == "__main__":
    # Run main() and use its return value as the exit status.
    raise SystemExit(main())

Note that I only used primes $p$ with $1000 \le p \le 20000$. If $p$ is too large, finding the residue takes too much time (up to $p$ queries). If $p$ is too small, a different problem arises, which I explain below.

Finally, I had many pairs $(r_i, p_i)$ with $d \equiv \pm r_i \pmod{p_i}$. For a point $Q$ of order $n$ (here $n = p_i$), the $x$-coordinate of $kQ$ equals that of $(n - k)Q$. We therefore cannot distinguish $r_i$ from $p_i - r_i$, and the sign of each residue has to be brute-forced. With $k$ pairs there are $2^k$ combinations, so to search the whole space the number of pairs should be at most around 20–30. That is why $p$ should not be too small: fewer, larger primes are needed for their product to exceed the key.

solve.sage
import pickle

from Crypto.Util.number import long_to_bytes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes

# prime -> residue r with d = +-r (mod prime), written by find_residual.py.
with open("./prime_to_res.pickle", "rb") as fh:
    prime_to_res = pickle.load(fh)


class AuthCipher(object):
    """Trimmed-down AuthCipher that derives only the AES key from `secret`.

    `mac_info` is accepted for signature compatibility but is not used.
    """

    def __init__(self, secret, cipher_info, mac_info):
        self.cipher_key = self.derive_key(secret, cipher_info)

    def derive_key(self, secret, info):
        """Return a 16-byte HKDF-SHA256 key (no salt) for `secret` and `info`."""
        kdf = HKDF(algorithm=hashes.SHA256(), length=16, salt=None, info=info)
        return kdf.derive(secret)


from itertools import product

FLAG_CIPHER_KDF_INFO = b"Flag Cipher v1.0"
FLAG_MAC_KDF_INFO = b"Flag MAC v1.0"

# Encrypted flag (presumably copied from find_residual.py's ServerHello printout).
iv = b"s@v\xd5g\xe0\t*\xbc\xe1\t\x15\x82UC}"
data = b'>}"B\xea"WgA\x9c*\x0cp\xd6b\\O6\xfc\xa8\x8fK\xe3\xdcU\xfc\xaa~\xb7\x16\xd5\x8aJ\xcf8M\xec{q\x99\x81\xc8\xe9yyj`3_\x94^\xcb\x84P\x80\xd3\x9b='

# Use the 20 largest recovered primes. Sort explicitly: prime_to_res is a dict
# whose key order depends on how find_residual.py was run (e.g. --reversed),
# so its last 20 keys are not necessarily the largest primes. Smaller primes
# may give a CRT modulus too small to determine the key.
primes = sorted(prime_to_res)[-20:]
b_list = primes
a_list = [prime_to_res[prime] for prime in primes]
# x(i*Q) == x((p - i)*Q) for Q of order p, so each residue is only known up to
# sign: the key is r or p - r modulo p.
a_inv_list = [b - a for a, b in zip(a_list, b_list)]

# Try each sign combination (2^len(primes) candidates) and stop at the first
# one whose derived flag key decrypts to something starting with b"CTF".
for residues in product(*zip(a_list, a_inv_list)):
    shared_key = long_to_bytes(int(crt(list(residues), b_list)))
    channel = AuthCipher(
        shared_key.rjust(28, b"\x00"), FLAG_CIPHER_KDF_INFO, FLAG_MAC_KDF_INFO,
    )
    key = channel.cipher_key
    decryptor = Cipher(algorithms.AES(key), modes.CTR(iv)).decryptor()
    pt = decryptor.update(data) + decryptor.finalize()
    if pt.startswith(b"CTF"):
        print(pt)
        break
else:
    print("flag not found; try more or larger primes")

CTF{ChocolateDoesNotAskSillyQuestionsChocolateUnderstands}

TONALITY

30 solves

This challenge is about ECDSA. The source code says the session is as follows:

  • [server] Generate a private key $d$ and compute the public key $Q = dG$; send $Q$.
  • [client] Send a scalar $t$.
  • [server] Sign message $m_0$ with ECDSA as if the private key were $td$ instead of $d$, producing $(r_0, s_0)$; send $(r_0, s_0)$.
  • [client] Send $(r_1, s_1)$.
  • [server] Check whether $(r_1, s_1)$ is a valid signature of message $m_1$ under the private key $d$. If so, send the flag.

In mathematical notation, to get the flag we must choose $t$ so that we can compute $(r_1, s_1)$ satisfying:

$$\begin{aligned} s_0 &= k_0^{-1} (H(m_0) + r_0 t d) \mod n \\ s_1 &= k_1^{-1} (H(m_1) + r_1 d) \mod n \end{aligned}$$

Here $r_i$ is the $x$-coordinate of $k_i G$, $H(\cdot)$ is the message hash, and $n$ is the order of $G$.

We have $s_0 = k_0^{-1} t (t^{-1} H(m_0) + r_0 d)$. So if we choose $(t, r_1)$ such that $t^{-1} H(m_0) = H(m_1)$ and $r_1 = r_0$, then

$$\begin{aligned} s_0 &= k_0^{-1} H(m_0) H(m_1)^{-1} (H(m_1) + r_1 d) \mod n \\ s_1 &= k_1^{-1} (H(m_1) + r_1 d) \mod n \end{aligned}$$

Taking $r_1 = r_0$ corresponds to reusing the server's nonce, $k_1 = k_0$. Dividing the two equations then eliminates both $d$ and the nonce, giving $s_1 = s_0 H(m_1) H(m_0)^{-1} \bmod n$. I sent this $(r_1, s_1)$ and got the flag!

calculate_hash.go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha1"
	"fmt"
	"io"
	"math/big"
)

// hashMessage returns the SHA-1 digest of the bytes of m.
func hashMessage(m string) []byte {
	hasher := sha1.New()
	io.WriteString(hasher, m)
	digest := hasher.Sum(nil)
	return digest
}

// m0 is the message the server signs; m1 is the message we must forge.
const (
	m0 = "Server says 1+1=2"
	m1 = "Server says 1+1=3"
)

func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
	orderBits := c.Params().N.BitLen()
	orderBytes := (orderBits + 7) / 8
	if len(hash) > orderBytes {
		hash = hash[:orderBytes]
	}

	ret := new(big.Int).SetBytes(hash)
	excess := len(hash)*8 - orderBits
	if excess > 0 {
		ret.Rsh(ret, uint(excess))
	}
	return ret
}

func main() {
	sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	hm0 := hashMessage(m0)
	hm1 := hashMessage(m1)
    hm0_int := hashToInt(hm0, sk.Curve)
    hm1_int := hashToInt(hm1, sk.Curve)
	fmt.Println(hm0_int)
	fmt.Println(hm1_int)
}
solve.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import pwnlib
import challenge_pb2
import struct
import sys

from Crypto.Util.number import long_to_bytes, bytes_to_long


def handle_pow(tube=None):
    """Solve the server's proof-of-work challenge (not implemented).

    Raises:
        NotImplementedError: always.
    """
    # NotImplemented is a constant, not an exception class; calling it raised
    # a confusing TypeError instead of the intended error.
    raise NotImplementedError("proof-of-work solving is not implemented")


def read_message(tube, typ):
    """Receive a 4-byte little-endian length, then that many bytes parsed as `typ`."""
    header = tube.recvnb(4)
    size = struct.unpack("<L", header)[0]
    payload = tube.recvnb(size)
    result = typ()
    result.ParseFromString(payload)
    return result


def write_message(tube, msg):
    """Serialize `msg` and send it with a 4-byte little-endian length prefix."""
    serialized = msg.SerializeToString()
    for frame in (struct.pack("<L", len(serialized)), serialized):
        tube.send(frame)


def main():
    """Forge an ECDSA signature of m1 from the server's signature of m0.

    The server signs m0 with the scaled key t*d. Choosing
    t = H(m0) / H(m1) mod n makes its signature satisfy
    s0 = k^-1 * t * (H(m1) + r0*d), so (r0, s0 * H(m1) / H(m0)) is a valid
    signature of m1 under d.

    Returns:
        0, used as the process exit status.
    """
    import hashlib

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--port", metavar="P", type=int, default=1337, help="challenge #port"
    )
    parser.add_argument(
        "--host",
        metavar="H",
        type=str,
        default="tonality.2021.ctfcompetition.com",
        help="challenge host",
    )
    args = parser.parse_args()

    # Order of the P-256 base point.
    n = 115792089210356248762697446949407573529996955224135760342422259061068512044369

    def message_hash(message):
        # Same value calculate_hash.go prints: the SHA-1 digest read as a
        # big-endian integer (160 bits < 256 bits, so hashToInt does not
        # truncate). Replaces the previously hard-coded integers.
        return int.from_bytes(hashlib.sha1(message).digest(), "big")

    hm0 = message_hash(b"Server says 1+1=2")
    hm1 = message_hash(b"Server says 1+1=3")

    tube = pwnlib.tubes.remote.remote(args.host, args.port)
    try:
        if args.host != "localhost":
            print(tube.recvuntil("== proof-of-work: "))
            if tube.recvline().startswith(b"enabled"):
                handle_pow(tube)

        # Step 1: Hello.
        read_message(tube, challenge_pb2.HelloResponse)

        # Step 2: Sign. Ask for a signature of m0 under the key t*d.
        t = hm0 * pow(hm1, -1, n) % n
        sign_req = challenge_pb2.SignRequest()
        sign_req.scalar = t.to_bytes((t.bit_length() + 7) // 8, "big")
        write_message(tube, sign_req)

        sign_res = read_message(tube, challenge_pb2.SignResponse)

        r0 = bytes_to_long(sign_res.message0_sig.r)
        s0 = bytes_to_long(sign_res.message0_sig.s)

        # Step 3: Verify. Reuse r0 and rescale s0 to forge a signature of m1.
        verify_req = challenge_pb2.VerifyRequest()

        r1 = r0
        s1 = s0 * hm1 * pow(hm0, -1, n) % n
        verify_req.message1_sig.r = long_to_bytes(r1)
        verify_req.message1_sig.s = long_to_bytes(s1)
        write_message(tube, verify_req)

        verify_res = read_message(tube, challenge_pb2.VerifyResponse)
        print(verify_res)
    finally:
        tube.close()
    return 0


if __name__ == "__main__":
    # Run main() and use its return value as the exit status.
    raise SystemExit(main())

CTF{TheySayTheEmptyCanRattlesTheMost}