zer0pts CTF 2022 Writeup

Sun Mar 20 2022

3/19-20 に開催していた zer0pts CTF 2022 にチーム WreckTheLine で参加しました。結果は 17th/632 でした。

毎度 zer0pts の人たちが作る CTF は面白いのでとても楽しみにしていた CTF です。去年の zer0pts CTF 参加時と比べると特に Crypto が解けるようになってきたので (去年の時点では楕円曲線も LLL もよくわからないレベルだった)、 Crypto の力試しという気持ちがありました。結果的には1問解けなかったので悔しいです、精進は終わらない… 以下解いた問題についての writeup です。

crypto

Karen

8 solves

task.sage
with open("flag.txt", "rb") as f:
    flag = int.from_bytes(f.read(), "big")

n = 70
m = flag.bit_length()
assert n < m
p = random_prime(2**512)


F = GF(p)
x = random_matrix(F, 1, n)
A = random_matrix(ZZ, n, m, x=0, y=2)
A[randint(0, n-1)] = vector(ZZ, Integer(flag).bits())
h = x*A

print(p)
print(list(h[0]))

x(Z/pZ)nx \in (\Z/p\Z)^nA{0,1}n×mA \in \{0, 1\}^{n\times m} がランダムに生成されます。このうち、 AA の一つの行はフラグのビット列となっています。与えられているのは ppmodp\mod p の空間で計算された xAxA です。これらから x,Ax, A を復元することができればフラグが入手できます。

この問題は部分和問題から派生した問題に見え、既に先行研究がありそうなため、ググるところから始めました。ちょっと調べてみると、この問題は hidden subset sum problem と呼ばれているみたいです。もう少し調べると https://eprint.iacr.org/2020/461.pdf を見つけました。 Nguyen-Stern algorithm という古典的な手法 (1997年) があるみたいです (※上記リンクの論文自体は、 Nguyen-Stern algorithm を改良したという話です)。 Nguyen-Stern algorithm sagemath とかで調べると、 https://gist.github.com/grocid/62081c82c077eae83f61a9c03b405c84 で sagemath 実装を見つけました。

…なのですみません、 script kiddy をしてしまいました…やっていることはただの OSINT です。上記 gist の script を走らせて x,Ax, A を復元し、 AA のそれぞれの行ベクトルをバイト列に変換し、フラグとなっているものを見つければ OK です。

zer0pts{Karen_likes_orthogonal_as_you_like}

アルゴリズムは全然理解していないので後で余力があれば追記しようと思います。でも多分作問者 writeup とかで解説してくれそうな気がするし、それを読んだら満足する気がする…

この問題は途中で配布ファイルが変更されており、変更前は n=8n = 8 でした。自分は変更前のファイルでしばらくの間解いてました… n=8n = 8 だと hihj(0i,j<m)h_i - h_j (0 \le i, j < m)xkx_kpxkp - x_k となる確率がかなり高いため、そこから xx を求めることが可能です。 しかし当然 n=70n = 70 では hihjh_i - h_j は相異なる数となり、この解法は通りません。自分はファイル更新に気づかず、 admin の人にファイルが間違ってないかと質問してしまいました。ファイル更新のアナウンスをまず確認すべきでしたね。大変申し訳ないです。

EDDH

24 solves

server.py
from random import randrange
from Crypto.Util.number import inverse, long_to_bytes
from Crypto.Cipher import AES
from hashlib import sha256
import ast
import os
import signal

n = 256
p = 64141017538026690847507665744072764126523219720088055136531450296140542176327
a = 362
d = 1
q = 64141017538026690847507665744072764126693080268699847241685146737444135961328
c = 4
gx = 36618472676058339844598776789780822613436028043068802628412384818014817277300
gy = 9970247780441607122227596517855249476220082109552017755637818559816971965596

def xor(xs, ys):
    return bytes(x^y for x, y in zip(xs, ys))

def pad(b, l):
    return b + b"\0" + b"\xff" * (l - (len(b) + 1))

def unpad(b):
    l = -1
    while b[l] != 0:
        l -= 1
    return b[:l]

def add(P, Q):
    (x1, y1) = P
    (x2, y2) = Q

    x3 = (x1*y2 + y1*x2) * inverse(1 + d*x1*x2*y1*y2, p) % p
    y3 = (y1*y2 - a*x1*x2) * inverse(1 - d*x1*x2*y1*y2, p) % p
    return (x3, y3)

def mul(x, P):
    Q = (0, 1)
    x = x % q
    while x > 0:
        if x % 2 == 1:
            Q = add(Q, P)
        P = add(P, P)
        x = x >> 1
    return Q

def to_bytes(P):
    x, y = P
    return int(x).to_bytes(n // 8, "big") + int(y).to_bytes(n // 8, "big")

def send(msg, share):
    assert len(msg) <= len(share)
    print(xor(pad(msg, len(share)), share).hex())

def recv(share):
    inp = input()
    msg = bytes.fromhex(inp)
    assert len(msg) <= len(share)
    return unpad(xor(msg, share))

def main():
    signal.alarm(300)

    flag = os.environ.get("FLAG", "0nepoint{frog_pyokopyoko_3_pyokopyoko}")
    assert len(flag) < 2*8*n
    while len(flag) % 16 != 0:
        flag += "\0"

    G = (gx, gy)
    s = randrange(0, q)

    print("sG = {}".format(mul(s, G)))
    tG = ast.literal_eval(input("tG = "))  # you should input something like (x, y)
    assert len(tG) == 2
    assert type(tG[0]) == int and type(tG[1]) == int
    share = to_bytes(mul(s, tG))

    while True:
        msg = recv(share)
        if msg == b"flag":
            aes = AES.new(key=sha256(long_to_bytes(s)).digest(), mode=AES.MODE_ECB)
            send(aes.encrypt(flag.encode()), share)

        elif msg == b"quit":
            quit()

        else:
            send(msg, share)

if __name__ == '__main__':
    main()

Edwards curve 上で ECDH をする問題で、相手の生成した秘密の値 (これなんて言うんだろう、秘密鍵ではない気がする) がわかればフラグを復元できます。

楕円曲線のパラメータがハードコードされている場合、 Pohlig-Hellman が使えるかが初手で気になるところなので確認します。位数として qq が定義されているっぽいので qq を素因数分解してみると、十分大きな素数の積となっており、ナイーブには Pohlig-Hellman が使えないことがわかります。

この問題を注意深く見ると、実は addmul を計算するときに、 (x,y)(x, y) が Edwards curve 上にあるかどうかをチェックしていません。なので qq とは異なる (素因数も異なるという意味) 値で周期構造が生じうる可能性があります。 例えば s×(0,y)s \times (0, y) という計算をすると、 (0,ys)(0, y^s) となることがわかります (add 関数の x1,x2x_1, x_2 に0を代入するとわかる)。これの位数は p1p - 1 です。 p1p - 1 を素因数分解すると、 Pohlig-Hellman が適用できる程度に十分小さい素数の積となっていることがわかります。

以上の考察から (0,2)(0, 2) などの値を tG として送信すれば share の値から s を求めることができます。しかし DH 鍵共有の手順に則っていないため、 share を単純には知ることができません。 今回の問題では送信した msg に対して share を使って計算した値を返してくれます。これを利用します。 msg\x00 とすると、返ってくる値は空文字列に pad したものと share とで xor したものになります。これから share が特定できます。

手順をまとめると、 (0,2)(0, 2)tG として送信→ share のリーク→ ss の特定→暗号化したフラグを入手→復号となります。

solve.sage
import re
from binascii import unhexlify
from hashlib import sha256

from Crypto.Cipher import AES
from Crypto.Util.number import long_to_bytes
from pwn import remote

Z = Zmod(p)

io = remote("crypto.ctf.zer0pts.com", 10929)
ret = io.recvline().strip().decode()
sGx, sGy = map(int, re.findall(r"sG = \((.*), (.*)\)", ret)[0])
sG = (sGx, sGy)
io.sendlineafter(b"tG = ", b"(0, 2)")
io.sendline(b"00")
send_message = int(io.recvline().strip().decode(), 16)
shared_int = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff ^^ send_message
s = Z(shared_int).log(2)
share = to_bytes(mul(s, (0, 2)))
io.sendline(b"666c616700")
enc = unhexlify(io.recvline().strip())
enc_xor = unpad(xor(enc, share))
aes = AES.new(key=sha256(long_to_bytes(s)).digest(), mode=AES.MODE_ECB)
print(aes.decrypt(enc_xor))

zer0pts{edwards_what_the_hell_is_this}

CurveCrypto

36 solves

task.py
import os
from random import randrange
from Crypto.Util.number import bytes_to_long, long_to_bytes, getStrongPrime
from Crypto.Util.Padding import pad
from fastecdsa.curve import Curve

def xgcd(a, b):
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b != 0:
        q, a, b = a // b, b, a % b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0

def gen():
    while True:
        p = getStrongPrime(512)
        if p % 4 == 3:
            break
    while True:
        q = getStrongPrime(512)
        if q % 4 == 3:
            break
    n = p * q
    a = randrange(n)
    b = randrange(n)

    while True:
        x = randrange(n)
        y2 = (x**3 + a*x + b) % n
        assert y2 % n == (x**3 + a*x + b) % n
        if pow(y2, (p-1)//2, p) == 1 and pow(y2, (q-1)//2, q) == 1:
            yp, yq = pow(y2, (p + 1) // 4, p), pow(y2, (q + 1) // 4, q)
            _, s, t = xgcd(p, q)
            y = (s*p*yq + t*q*yp) % n
            break
    return Curve(None, n, a, b, None, x, y) 

def encrypt(m, G):
    blocks = [m[16*i:16*(i+1)] for i in range(len(m) // 16)]
    c = []
    for i in range(len(blocks)//2):
        G = G + G
        c.append(G.x ^ bytes_to_long(blocks[2*i]))
        c.append(G.y ^ bytes_to_long(blocks[2*i+1]))
    return c

def decrypt(c, G):
    m = b''
    for i in range(len(c) // 2):
        G = G + G
        m += long_to_bytes(G.x ^ c[2*i])
        m += long_to_bytes(G.y ^ c[2*i+1])
    return m

flag = pad(os.environ.get("FLAG", "fakeflag{sumomomomomomomomonouchi_sumomo_mo_momo_mo_momo_no_uchi}").encode(), 32)
C = gen()
c = encrypt(flag, C.G)
assert decrypt(c, C.G) == flag

print("n = {}".format(C.p))
print("a = {}".format(C.a))
print("b = {}".format(C.b))
print("c = {}".format(c))

Z/nZ\Z/n\Z (nn はクソデカ素数の積) 上での楕円曲線のパラメータが与えられます。曲線上の1つの点 GG が定義されており、 2G,4G2G, 4Gx,yx, y 座標とフラグの xor を取った暗号も与えられています。

フラグの値は未知なのと、 xor は算術的な計算に不向きなため、 cc'cc の下位128ビットをマスクしたものとして、 (2G)x=c0+Δx2,(2G)y=c1+Δy2,(4G)x=c2+Δx4,(4G)y=c3+Δy4(2G)_x = c'_0 + \Delta x_2, (2G)_y = c'_1 + \Delta y_2, (4G)_x = c'_2 + \Delta x_4, (4G)_y = c'_3 + \Delta y_4 とおきます。 Δxi,Δyi\Delta x_i, \Delta y_i は128ビット程度で、他の値と比べると十分小さいです。

当然 2G+2G=4G2G + 2G = 4G なため、 (4G)x,(4G)y(4G)_x, (4G)_y はどちらも (2G)x,(2G)y(2G)_x, (2G)_y を用いて書き表すことができます。また 2G,4G2G, 4G はどちらも楕円曲線上の点なため (2G)x,(2G)y(2G)_x, (2G)_y(4G)x,(4G)y(4G)_x, (4G)_y の関係式も書き表せます。 4変数に対して4式が立てられたのでこれを解いていきます。グレブナー基底を使えば簡約化された多項式が生成されるので、それに対して coppersmith を用いて変数を求めていきます。多変数 coppersmith には毎度おなじみ defund/coppersmith を用いました。 この手続きで4変数を求めたら、もとの cc と xor を取ってフラグが復元できます。

…ここまで書いて気づきましたが、2G,4G2G, 4G が楕円曲線上であるという式は十分に次元が小さい多項式かつそれぞれ2変数のみなので、その式で coppersmith 使えば普通に求まりますね…下のコードはグレブナー基底を使ったものになってますが賢い人は polys[2], polys[3] それぞれで coppersmith を使いましょう (というか多分グレブナー基底で簡約化された多項式と多分一致している。手計算で polys[0], polys[1] を計算するの大変だったけどいらんかったな…)。

solve.sage
masked_c = [cc >> 128 << 128 for cc in c]

PRall.<dx2, dy2, dx4, dy4> = PolynomialRing(Zmod(n))
polys = [
    4 * (masked_c[1] + dy2)**2 * (masked_c[2] + dx4 + 2*masked_c[0] + 2*dx2) - (3 * (masked_c[0] + dx2) ** 2 + a) ** 2,
    2 * (masked_c[1] + dy2) * (masked_c[3] + dy4) - (3 * (masked_c[0] + dx2)**2 + a) * (masked_c[0] + dx2 - masked_c[2] - dx4) + 2 * (masked_c[1] + dy2)**2,
    (masked_c[1] + dy2) ** 2 - (masked_c[0] + dx2) ** 3 - a*(masked_c[0] + dx2) - b,
    (masked_c[3] + dy4) ** 2 - (masked_c[2] + dx4) ** 3 - a*(masked_c[2] + dx4) - b,
]
I = Ideal(polys)
basis = I.groebner_basis()

PR4.<dx4, dy4> = PolynomialRing(Zmod(n))
f = PR4(basis[-1])
res = small_roots(f, [2**128]*2, m=2, d=3)
dx4, dy4 = res[0]

PR2.<dx2, dy2> = PolynomialRing(Zmod(n))
f = PR2(str(basis[-2](dx4=dx4, dy4=dy4)))
res = small_roots(f, [2**128]*2, m=3, d=4)
dx2, dy2 = res[0]

G2x = masked_c[0] + dx2
G2y = masked_c[1] + dy2
G4x = masked_c[2] + dx4
G4y = masked_c[3] + dy4

m = b""
m += long_to_bytes(c[0] ^^ int(G2x))
m += long_to_bytes(c[1] ^^ int(G2y))
m += long_to_bytes(c[2] ^^ int(G4x))
m += long_to_bytes(c[3] ^^ int(G4y))

zer0pts{th3_g00d_3ncrypti0n_c0m3s_fr0m_th3_g00d_curv3}

Anti-Fermat

125 solves

task.py
from Crypto.Util.number import isPrime, getStrongPrime
from gmpy import next_prime
from secret import flag

# Anti-Fermat Key Generation
p = getStrongPrime(1024)
q = next_prime(p ^ ((1<<1024)-1))
n = p * q
e = 65537

# Encryption
m = int.from_bytes(flag, 'big')
assert m < n
c = pow(m, e, n)

print('n = {}'.format(hex(n)))
print('c = {}'.format(hex(c)))

素数の生成方法が特殊な RSA です。 qq0x1111....pp の xor をした値より大きい整数で最も小さい素数となっています。 このような生成方法ですので、 p+qp + q210242^{1024} よりちょっと大きい値となっています。 pq=n,p+q=21024+dpq = n, p + q = 2^{1024} + d として dd を for で回すと、ある dd のときに z2(21024+d)z+n=0z^2 - (2^{1024} + d)z + n = 0 の解が求まります。その求まった解が p,qp, q なのでいつもの手順で RSA の復号をすれば OK です。

solve.sage
from Crypto.Util.number import long_to_bytes

PR.<z> = PolynomialRing(ZZ)
for d in range(4000):
    p_q = 2**1024 + d
    f = z**2 - p_q * z + n
    roots = f.roots()
    if len(roots) != 0:
        break

p, q = roots[0][0], roots[1][0]
phi = (p - 1) * (q - 1)
d = int(pow(0x10001, -1, phi))
long_to_bytes(int(pow(c, d, n)))

zer0pts{F3rm4t,y0ur_m3th0d_n0_l0ng3r_w0rks.y0u_4r3_f1r3d}

misc

MathHash

57 solves

server.py
import struct
import math
import signal
import os

def MathHash(m):
    hashval = 0
    for i in range(len(m)-7):
        c = struct.unpack('<Q', m[i:i+8])[0]
        t = math.tan(c * math.pi / (1<<64))
        hashval ^= struct.unpack('<Q', struct.pack('<d', t))[0]
    return hashval

if __name__ == '__main__':
    FLAG = os.getenv('FLAG', 'zer0pts<sample_flag>').encode()
    assert FLAG.startswith(b'zer0pts')

    signal.alarm(1800)
    try:
        while True:
            key = bytes.fromhex(input("Key: "))
            assert len(FLAG) >= len(key)

            flag = FLAG
            for i, c in enumerate(key):
                flag = flag[:i] + bytes([(flag[i] + key[i]) % 0x100]) + flag[i+1:]

            h = MathHash(flag)
            print("Hash: " + hex(h))
    except:
        exit(0)

入力した key によって flag が改変され、その改変された flag のビット列を8ビットずつに区切り tan 関数に通したあとで再びビット列に戻し…といった操作が行われています。

まずフラグの文字列長を探索させました。手動2分探索しました。

Key: 00112233445566778899001122334455667788990011223344
Hash: 0x2e2bd0fd230716
Key: 0011223344556677889900112233445566778899001122334455
(ここで落ちる)

この結果から25文字であることがわかります。

手元で MathHash に与える値をいろいろ変えて実験していると、 b"\x00\x00...\x00\x00" + b"\x00" のときと b"\x00\x00...\x00\x00" + c (cb"\x00" 以外の任意の値) で MathHash の出力が大きく変化します。 あまり深くは理解していないですが、 tanh が連続関数であるものの、 struct.pack('<d', t)t = 0 周辺で不連続になっているからこのような挙動になっていると思われます。

key のロジックを見ると、各バイトで flag + key をしている感じです。これとフラグの前方が zer0pts{ となっていることを利用して、前方から1文字ずつフラグを特定できます。

"MathHash の出力が大きく変化する" が感覚的な理解なため、毎 iteration で目視確認して文字を特定しています…ダサい

solve.py
from Crypto.Util.number import long_to_bytes
from pwn import remote


io = remote("misc.ctf.zer0pts.com", 10001)

flag = b"zer0pts{"
key = bytes([256 - c for c in flag])
for idx in range(25 - len(flag)):
    values = []
    for i in range(256):
        tmp_key = key + long_to_bytes(i)
        io.sendlineafter(b"Key: ", tmp_key.hex())
        io.recvuntil(b"Hash: ")
        value = int(io.recvline().strip().decode(), 16)
        values.append(value)
        print(i, hex(value))
    m = int(input())
    key += long_to_bytes(m)
    flag += long_to_bytes(256 - m)
    print(flag, key)

zer0pts{s1gn+|3xp^|fr4c.}

web

GitFile Explorer

181 solves

index.php
<?php
function h($s) { return htmlspecialchars($s); }
function craft_url($service, $owner, $repo, $branch, $file) {
    if (strpos($service, "github") !== false) {
        /* GitHub URL */
        return $service."/".$owner."/".$repo."/".$branch."/".$file;

    } else if (strpos($service, "gitlab") !== false) {
        /* GitLab URL */
        return $service."/".$owner."/".$repo."/-/raw/".$branch."/".$file;

    } else if (strpos($service, "bitbucket") !== false) {
        /* BitBucket URL */
        return $service."/".$owner."/".$repo."/raw/".$branch."/".$file;

    }

    return null;
}

$service = empty($_GET['service']) ? "" : $_GET['service'];
$owner   = empty($_GET['owner'])   ? "ptr-yudai" : $_GET['owner'];
$repo    = empty($_GET['repo'])    ? "ptrlib"    : $_GET['repo'];
$branch  = empty($_GET['branch'])  ? "master"    : $_GET['branch'];
$file    = empty($_GET['file'])    ? "README.md" : $_GET['file'];

if ($service) {
    $url = craft_url($service, $owner, $repo, $branch, $file);
    if (preg_match("/^http.+\/\/.*(github|gitlab|bitbucket)/m", $url) === 1) {
        $result = file_get_contents($url);
    }
}
?>
<!DOCTYPE html>
<html>
    <head>
        <meta charset="UTF-8">
        <title>GitFile Explorer</title>
        <link rel="stylesheet" href="https://cdn.simplecss.org/simple-v1.css">
    </head>
    <body>
        <header>
            <h1>GitFile Explorer API Test</h1>
            <p>Simple API to download files on GitHub/GitLab/BitBucket</p>
        </header>
        <main>
            <form method="GET" action="/">
                <label for="service">Service: </label>
                <select id="service" name="service" autocomplete="off">
                    <option value="https://raw.githubusercontent.com" <?= strpos($service, "github") === false ? "" : 'selected="selected"' ?>>GitHub</option>
                    <option value="https://gitlab.com" <?= strpos($service, "gitlab") === false ? "" : 'selected="selected"' ?>>GitLab</option>
                    <option value="https://bitbucket.org" <?= strpos($service, "bitbucket") === false ? "" : 'selected="selected"' ?>>BitBucket</option>
                </select>
                <br>
                <label for="owner">GitHub ID: </label>
                <input id="owner" name="owner" type="text" placeholder="Repository Owner" value="<?= h($owner); ?>">
                <br>
                <label for="repo">Repository Name: </label>
                <input id="repo" name="repo" type="text" placeholder="Repository Name" value="<?= h($repo); ?>">
                <br>
                <label for="branch">Branch: </label>
                <input id="branch" name="branch" type="text" placeholder="Branch Name" value="<?= h($branch); ?>">
                <br>
                <label for="file">File Path: </label>
                <input id="file" name="file" type="text" placeholder="README.md" value="<?= h($file); ?>">
                <br>
                <input type="submit" value="Download">
            </form>
            <?php if (isset($result)) { ?>
                <br>
                <?php if ($result === false) { ?>
                    <p>Not Found :(</p>
                <?php } else {?>
                    <textarea rows="20" cols="40"><?= h($result); ?></textarea>
                <?php } ?>
            <?php } ?>
        </main>
        <footer>
            <p>zer0pts CTF 2022</p>
        </footer>
    </body>
</html>

craft_url の返り値が "/^http.+\/\/.*(github|gitlab|bitbucket)/m" を満たしているとき、その返り値の path のファイルを file_get_contents で見ることができます。

file_get_contents は URL を入れた場合はその URL の、 file path を入れた場合は local のファイルを表示します。このアプリケーションの想定した使われ方 は URL ですが、これを file path に捏造し、 /flag.txt を見る方法はないでしょうか。 正規表現を見てみると、 http:// で始まる必要は実はなくて、 http// とかでも通ります。このとき URL の scheme にはなりません。なのでこのような path を作れるような引数を craft_url に渡してあげます。 craft_urlgithub 等の文字列さえ入っていれば query parameter をただ結合しているだけです。自分は http://gitfile.ctf.zer0pts.com:8001/?service=https//github&owner=hoge&repo=fuga&branch=piyo&file=foo/../../../../../../../../../flag.txt で通しました。

zer0pts{foo/bar/../../../../../directory/traversal}

復習

crypto: OK

server.py
from Crypto.Util.number import isPrime, getPrime, getRandomRange, inverse
import os
import signal

signal.alarm(300)

flag = os.environ.get("FLAG", "0nepoint{GOLDEN SMILE & SILVER TEARS}")
flag = int(flag.encode().hex(), 16)

P = 2 ** 1000 - 1
while not isPrime(P): P -= 2

p = getPrime(512)
q = getPrime(512)
e = 65537
phi = (p-1)*(q-1)
d = inverse(e, phi)
n = p*q

key = getRandomRange(0, n)
ciphertext = pow(flag, e, P) ^ key

x1 = getRandomRange(0, n)
x2 = getRandomRange(0, n)

print("P = {}".format(P))
print("n = {}".format(n))
print("e = {}".format(e))
print("x1 = {}".format(x1))
print("x2 = {}".format(x2))

# pick a random number k and compute v = k**e + (x1|x2)
# if you add x1, you can get key = c1 - k mod n
# elif you add x2, you can get ciphertext = c2 - k mod n
v = int(input("v: "))

k1 = pow(v - x1, d, n)
k2 = pow(v - x2, d, n)

print("c1 = {}".format((k1 + key) % n))
print("c2 = {}".format((k2 + ciphertext) % n))

vv をこちらから指定ができ、そこから key ^ ciphertext の値をリークさせる問題です。 key ^ ciphertext がわかれば PP が素数なので簡単にフラグが入手できます。 以下では keykkciphertextcckey ^ ciphertextxx と表します。

競技中

この問題では対称性があります。添字の1が key に、2が ciphertext に対応している感じです。なのでこの対称性を崩さないような vv を選択したほうが見通しがよさそうです。 x1x2mod2x_1 \equiv x_2 \mod 2 のとき v=(x1+x2)/2v = (x_1 + x_2) / 2 とすると、 k1=((x2x1)/2)d,k2=((x1x2)/2)d=k1(d1mod2)k_1 = ((x_2 - x_1)/2)^d, k_2 = ((x_1 - x_2)/2)^d = -k_1 (\because d \equiv 1 \mod 2) が成り立ちます。これを利用すると、 c1+c2=k1+k+k2+c=k+cc_1 + c_2 = k_1 + k + k_2 + c = k + c となり、 k+cmodnk + c \mod n の値がわかります。 modn\mod n の不定性は排除したいです。 0k+c<2n0 \le k + c < 2n なので k+c=c1+c2i×n k + c = c_1 + c_2 - i \times n と書いたとき i=1,0,1i = -1, 0, 1 のいずれかになります。 k+c=x+2(kc)k + c = x + 2(k \land c) と表せることを利用すると、 k+ck + c の lsb は xx と一致します。なので xx のパリティを仮定してしまえば、 k+ck + c のパリティが xx と一致し、 0k+c<2n0 \le k + c < 2n となるように c1+c2c_1 + c_2nn を足し引きすることで exact に求まります。

再び k+c=x+2(kc)k + c = x + 2(k \land c) を考えます。 kck \land c は75%程度が0で25%程度が1となっている数のはずです。つまり k+cxk + c - x の値にはビットの偏りが存在します。 これを利用するため、 k+ck + c を数10パターン取得し、 k+cxk + c - x をバイナリで表したときに0が最大になるような xx を探索させました。

# csums == c + k の array

def search_i(args):
    i, x_bits = args
    n = len(x_bits)
    tmp = 2 ** n * i + int(x_bits, 2)
    cnt = Counter()
    for j in range(len(csums)):
        csum = csums[j]
        cnt += Counter(f"{csum - tmp:01024b}")
    n0 = cnt["0"]
    return n0

def search(x_bits, depth=1):
    with ProcessPoolExecutor() as executor:
        n0s = executor.map(search_i, [(i, x_bits) for i in range(2**depth)])
    return np.argmax(list(n0s))


x_bits = "1"  # lsb を1と仮定
for idx in tqdm(range(999)):
    i = search(x_bits, depth=11)
    x_bits = str(i % 2) + x_bits

lsb 側から1ビットずつ決めることを考えています。 depth bit 分の数を全探索させてもっとも0が多くなった数の lsb を x_bits に足していっています。

実験だとこの方法は95%程度のビットは特定できるのですが、残り5%が正しく求まりません。この間違っている5%の場所がどこかもわからないのでここで手詰まりでした…つらい

競技後

discord での議論を眺めていると、もっと単純に xx を1ビットずつ求める方法があるみたいでした。

xxii ビット目を注目すると、この ii ビット目の carry (足し算の繰り上がり) が0であろうが1であろうが不変な特徴量を見つけられると嬉しいです。 ii ビット目だけをみてもこのような特徴量は存在しなそうなので i+1i + 1ビット目も注目します。頭で考えてもわからないので真理値表を書き出してみます。

from collections import defaultdict
from itertools import product


d = defaultdict(list)
for c0, k0, c1, k1, car0 in product(range(2), repeat=5):
    d["c0"].append(c0)
    d["k0"].append(k0)
    d["c1"].append(c1)
    d["k1"].append(k1)
    d["car0"].append(car0)
    d["c0 ^ k0"].append(c0 ^ k0)
    d["c0 + k0"].append(c0 ^ k0 ^ car0)
    car1 = 1 if c0 + k0 + car0 >= 2 else 0
    d["c1 + k1"].append(c1 ^ k1 ^ car1)
df = pd.DataFrame(d)
df
    c0  k0  c1  k1  car0  c0 ^ k0  c0 + k0  c1 + k1
0    0   0   0   0     0        0        0        0
1    0   0   0   0     1        0        1        0
2    0   0   0   1     0        0        0        1
3    0   0   0   1     1        0        1        1
4    0   0   1   0     0        0        0        1
5    0   0   1   0     1        0        1        1
6    0   0   1   1     0        0        0        0
7    0   0   1   1     1        0        1        0
8    0   1   0   0     0        1        1        0
9    0   1   0   0     1        1        0        1
10   0   1   0   1     0        1        1        1
11   0   1   0   1     1        1        0        0
12   0   1   1   0     0        1        1        1
13   0   1   1   0     1        1        0        0
14   0   1   1   1     0        1        1        0
15   0   1   1   1     1        1        0        1
16   1   0   0   0     0        1        1        0
17   1   0   0   0     1        1        0        1
18   1   0   0   1     0        1        1        1
19   1   0   0   1     1        1        0        0
20   1   0   1   0     0        1        1        1
21   1   0   1   0     1        1        0        0
22   1   0   1   1     0        1        1        0
23   1   0   1   1     1        1        0        1
24   1   1   0   0     0        0        0        1
25   1   1   0   0     1        0        1        1
26   1   1   0   1     0        0        0        0
27   1   1   0   1     1        0        1        0
28   1   1   1   0     0        0        0        0
29   1   1   1   0     1        0        1        0
30   1   1   1   1     0        0        0        1
31   1   1   1   1     1        0        1        1

c0, k0, car0 が下位のビットで car0 は carry を表しています。 c1, k1 は隣の上位ビットです。 この真理値表から、 c0, k0, c1, k1 を固定して car0 が変わっても変わらない特徴量を気合で探します。 c0 ^ k0 が1のときは c0 + k0c1 + k1 の値が必ず同じか必ず異なっているのに対し、 c0 ^ k0 が0のときはランダムに変化します。これに気づくのむずすぎませんか…?

solve.py
from pwn import remote

io = remote("crypto.ctf.zer0pts.com", 10333)


def recv_value(symbol):
    io.recvuntil(f"{symbol} = ")
    return int(io.recvline())


c1s = []
c2s = []
ns = []

for _ in range(100):
    io = remote("crypto.ctf.zer0pts.com", 10333)
    P = recv_value("P")
    n = recv_value("n")
    e = recv_value("e")
    x1 = recv_value("x1")
    x2 = recv_value("x2")
    if (x1 + x2) % 2 == 1:
        io.close()
        continue
    v = (x1 + x2) // 2
    io.sendlineafter(b"v: ", str(v).encode())
    c1 = recv_value("c1")
    c2 = recv_value("c2")
    c1s.append(c1)
    c2s.append(c2)
    ns.append(n)
    io.close()

csums = []
enc_flag_parity = 1  # x の lsb を1と仮定
for i in range(len(c1s)):
    c1 = c1s[i]
    c2 = c2s[i]
    n = ns[i]
    tmp_sum = c1 + c2
    res = None
    if tmp_sum % 2 == enc_flag_parity:
        res = tmp_sum
    else:
        if tmp_sum >= n:
            res = tmp_sum - n
        else:
            res = tmp_sum + n
    csums.append(res)

M = 20
dec = ""
for i in reversed(range(1, 1001)):
    cnt = Counter()
    for j in range(M):
        csum = csums[j]
        s1 = (csum & (1 << i)) // (1 << i)
        s0 = (csum & (1 << (i - 1))) // (1 << (i - 1))
        cnt[s0 == s1] += 1
    if cnt.most_common(1)[0][1] == M:
        dec += "1"
    else:
        dec += "0"

P = 10715086071862673209484250490600018105614048117055336074437503883703510511249361224931983788156958581275946729175531468251871452856923140435984577574698574803934567774824230985421074605062371141877954182153046474983581941267398767559165543946077062914571196477686542167660429831652624386837205668068131
e = 0x10001
flag_d = int(pow(e, -1, P - 1))
long_to_bytes(pow(int(dec, 2), flag_d, P))

zer0pts{hav3_y0u_unwittin91y_acquir3d_th3_k3y_t0_th3_d00r_t0_th3_N3w_W0r1d?}

crypto: Karen

コピペで解いてしまったので、その原理のお気持ちを復習しました。参考にした solver は https://gist.github.com/grocid/62081c82c077eae83f61a9c03b405c84 です。 問題では m=351m = 351 ですが、以下の解法だと大きすぎて LLL の計算時間が長くなるため、 m=176m = 176 程度で計算させます。

全体の流れは

  • Mht=0modpMh^t = 0 \mod p となる MZm×mM \in \Z^{m\times m} を求める
  • MM を LLL で簡約し、新たに MM とする
  • MKt=0modpMK^t = 0 \mod p となる KZn×mK \in \Z^{n \times m} を求める
  • KK を LLL で簡約し、 1,0,1-1, 0, 1 のみからなる行列とする。これを新たに KK とする
  • KK のベクトルの足し引きで 0,10, 1 のみからなるベクトルを nn 個つくる。これがもとの行列 AA である。

となります。

  • Mht=0modpMh^t = 0 \mod p となる MZm×mM \in \Z^{m\times m} を求める

これは雑に作ればいいだけです。

def orthogonal_matrix(h, p):
    """Returns M such that Mh^t = 0 mod p
    h: row vector
    p: integer
    """
    m = h.length()
    M = matrix(ZZ, m, m)
    M[0, 0] = p
    for i in range(1, m):
        M[i, i] = 1
    M[1: m, 0] = -h[1: m] * pow(h[0], -1, p) % p
    return M


M = orthogonal_matrix(h, p)
tmp = M * h
assert all([tmp[i] % p == 0 for i in range(m)])
  • MM を LLL で簡約し、新たに MM とする

これもやるだけ。

M = M.LLL()
  • MKt=0modpMK^t = 0 \mod p となる KZn×mK \in \Z^{n \times m} を求める

これは .right_kernel() で計算できます。

K = M[:m-n].right_kernel().matrix()[:n]

このあたりで今までの計算を整理しておきます。 MKtMK^t の上位 mnm - n 行はすべて0になっています。また、 MMhh に直交しているため、 Mht=MAtxt=0Mh^t = MA^tx^t = 0 です。したがって KK の行ベクトルの線形和で AA を表せることを示唆しています。これを以下では求めていきます。

  • KK を LLL で簡約し、 1,0,1-1, 0, 1 のみからなる行列とする。これを新たに KK とする

LLL を使えば 1,0,1-1, 0, 1 のみからなる行列を作れることが期待できます。上記 solver では block_size を徐々に大きくしながら BKZ を使って解いています (あまり原理はわかっていない…)。この問題では LLL で達成できたので深追いしません。

K = K.LLL()
  • KK のベクトルの足し引きで 0,10, 1 のみからなるベクトルを nn 個つくる。これがもとの行列 AA である。

0,10, 1 のみのベクトルを探索させて nn 個見つけます。前のステップで 1,0,1-1, 0, 1 のみになっているので探索が十分に可能です。

vs = []
for i in range(n):
    if len(set(K[i])) == 2:
        v = K[i]
        if -1 in v:
            v = -v
        break
vs.append(v)
for v in vs:
    for i in range(n):
        set_plus = set(v + K[i])
        set_minus = set(v - K[i])
        if set_plus in [{0, 1}, {0, -1}]:
            tmp = v + K[i]
        elif set_minus in [{0, 1}, {0, -1}]:
            tmp = v - K[i]
        else:
            continue
        if -1 in tmp:
            tmp = -tmp
        if tmp not in vs:
            vs.append(tmp)
A = matrix(ZZ, vs)

これで AA が求まりました。 この問題では AA を求めるだけで十分ですが、 xx を求めたい場合は

x = A[:, :n].solve_left(h[:n]) % p

で求まります。