2026 Dice CTF Qual Writeup - dot

Before all

Rank: 23
Team: fewer

這是我第一次跟 fewer 一起打比賽,解了一題有趣的 Crypto 題
明顯感覺得到近年為了抗衡 AI Agents(事實上,這已經幾乎不可能),題目都走向了更難、code base 更龐大的方向。
畢竟 dice ctf 我去年也有解一題 crypto,難度根本三級跳

這是一道關於 ZKP / SNARG 的題目,可以學到 DPP(dot-product proof)的結構。

題目過一陣子應該在 https://github.com/dicegang

Writeup

dot

Server 會走 SNARG 和 DPP 的組合來確認是否對於提供的 a, b,你輸入的 c 和 proof 是可以證明 a + b = c 的。

要連續偽證成功 20 次(中途任何一次失敗 streak 都會歸零)才可以拿到 flag

先來看 dpp.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import random
from dataclasses import dataclass
from itertools import chain
from typing import Iterator
from circuit import AndGate, Circuit, XorGate

# Rebind the name `random` to a SystemRandom instance, so every
# random.randint(...) call below draws from the OS entropy source.
random = random.SystemRandom()

State = tuple[int, int]  # verifier state (b, val) returned by sample()
Vector = list[int]  # integer query / proof vectors

@dataclass
class LinearConstraint:
    """One linear constraint over the proof vector ``pi``.

    ``pi`` satisfies the constraint when
    ``sum(scalar * pi[idx] for idx, scalar in scalars) == constant``.
    """

    scalars: list[tuple[int, int]]  # (proof index, coefficient) pairs
    constant: int  # required value of the weighted sum

def trace_len(circuit: Circuit) -> int:
    """Return the number of entries in the execution trace.

    The trace holds one entry per circuit input followed by one entry per
    gate output, i.e. ``len(circuit.inputs) + len(circuit.gates)``.
    """
    return len(circuit.inputs) + len(circuit.gates)

def proof_len(circuit: Circuit) -> int:
    """Return the length of a DPP proof vector.

    The proof is the trace (``n`` entries) followed by every product
    ``x_i * x_j`` with ``j <= i`` (``n * (n + 1) / 2`` entries).
    """
    n = trace_len(circuit)
    return n + n * (n + 1) // 2

def pair_index(circuit: Circuit, index1: int, index2: int) -> int:
    """Return the proof position of the product ``x_index1 * x_index2``.

    Argument order does not matter: each unordered pair is stored once, keyed
    by ``(i, j)`` with ``i >= j``, in a row-major lower-triangular layout where
    row ``i`` holds ``j = 0..i``.
    """
    i = max(index1, index2)
    j = min(index1, index2)
    # Skip the trace, then the i rows before row i (1 + 2 + ... + i entries).
    return trace_len(circuit) + i * (i + 1) // 2 + j

def input_constraints(circuit: Circuit) -> Iterator[LinearConstraint]:
    """Yield a booleanity constraint for every circuit input.

    For input ``i`` the constraint is ``x_i - pi[pair(i, i)] == 0``. When the
    product entry really equals ``x_i * x_i``, this forces ``x_i`` to be 0 or 1.
    """
    for i in range(len(circuit.inputs)):
        yield LinearConstraint([(i, 1), (pair_index(circuit, i, i), -1)], 0)

def gate_constraints(circuit: Circuit) -> Iterator[LinearConstraint]:
    """Yield one linear constraint per gate tying its output to its operands.

    An operand wire whose ``constant`` flag is ``nl`` contributes the literal
    ``nl + sl * x`` with ``sl = 1 - 2 * nl``. That is ``x`` when the flag is 0
    and ``1 - x`` when it is 1. Gate ``i``'s output sits at trace index
    ``len(circuit.inputs) + i``.

    * AND: ``out = (nl + sl*x_l) * (nr + sr*x_r)``, expanded into the product
      entry, the linear cross terms and the constant ``nl * nr``.
    * Every other gate is treated as XOR: ``out = L + R - 2*L*R``, which
      expands to ``sl*sr*(x_l + x_r - 2*x_l*x_r) + (nl ^ nr)``.
    """
    for i, gate in enumerate(circuit.gates):
        l = gate.left.index
        r = gate.right.index
        nl = int(gate.left.constant)
        nr = int(gate.right.constant)
        sl = 1 - 2 * nl
        sr = 1 - 2 * nr

        if isinstance(gate, AndGate):
            scalars = [(len(circuit.inputs) + i, 1), (pair_index(circuit, l, r), -sl * sr)]
            # A linear term on one operand only appears when the other
            # operand's flag is set.
            if sl * nr != 0:
                scalars.append((l, -sl * nr))
            if sr * nl != 0:
                scalars.append((r, -sr * nl))
            yield LinearConstraint(scalars, nl * nr)
        else:
            scalars = [(len(circuit.inputs) + i, 1), (l, -sl * sr), (r, -sl * sr), (pair_index(circuit, l, r), 2 * sl * sr)]
            yield LinearConstraint(scalars, nl ^ nr)

def output_constraints(circuit: Circuit) -> Iterator[LinearConstraint]:
    """Yield constraints forcing every output literal to be 0.

    An output wire with flag ``nl`` has literal ``nl + sl * x`` where
    ``sl = 1 - 2 * nl``. The constraint ``sl * x == -nl`` says exactly that
    this literal is zero.
    """
    for wire in circuit.outputs:
        sl = 1 - 2 * int(wire.constant)
        nl = int(wire.constant)
        yield LinearConstraint([(wire.index, sl)], -nl)

def tensor_queries(circuit: Circuit, bound: int) -> tuple[Vector, Vector]:
    """Build the tensor queries ``(q1, q2)`` that check the product entries.

    A secret vector ``v`` with entries in ``[-bound, bound]`` is drawn for the
    ``n`` trace positions.

    * ``q1`` carries ``v`` on the trace part, so
      ``<q1, pi> = k = sum(v[i] * x_i)``.
    * ``q2`` carries ``v[i] * v[i]`` on the diagonal product entries and
      ``2 * v[i] * v[j]`` on the off-diagonal ones (each unordered pair is
      stored once), so an honest proof gives ``<q2, pi> = k ** 2``.

    Both vectors have length ``proof_len(circuit)``.
    """
    n = trace_len(circuit)
    # NOTE: when called with snarg's BOUND1 = 2**8, each v[i] takes only 513
    # possible values.
    v = [random.randint(-bound, bound) for _ in range(n)]
    q1 = [0] * proof_len(circuit)
    q2 = [0] * proof_len(circuit)
    for i in range(n):
        q1[i] = v[i]
        for j in range(i + 1):
            q2[pair_index(circuit, i, j)] = v[i] * v[j] if i == j else 2 * v[i] * v[j]
    return (q1, q2)

def constraint_query(circuit: Circuit, bound: int) -> tuple[Vector, int]:
    """Fold every constraint into one query using a random linear combination.

    Each input, gate and output constraint is weighted by an independent
    ``r`` drawn from ``[-bound, bound]``.

    Returns ``(query, val)``. For any proof satisfying all constraints,
    ``<query, pi> == val``.
    """
    query = [0] * proof_len(circuit)
    val = 0
    constraints = chain(input_constraints(circuit), gate_constraints(circuit), output_constraints(circuit))
    for constraint in constraints:
        r = random.randint(-bound, bound)
        for idx, scalar in constraint.scalars:
            query[idx] += r * scalar
        val += r * constraint.constant
    return (query, val)

def prove(circuit: Circuit, inputs: Vector) -> Vector:
    """Return the honest DPP proof for ``inputs``.

    The proof is the evaluation trace followed by every product
    ``trace[i] * trace[j]`` with ``j <= i``, in row-major lower-triangular
    order (row ``i`` holds ``j = 0..i``).

    Asserts that every circuit output is 0.
    """
    outputs, trace = circuit.evaluate(inputs)
    assert all(o == 0 for o in outputs)
    proof = trace + [trace[i] * trace[j] for i in range(len(trace)) for j in range(i + 1)]
    return proof

def sample(circuit: Circuit, bound1: int, bound2: int) -> tuple[Vector, State]:
    """Sample the combined DPP query and the verifier state.

    ``q = q1 + b * (q2 - q3)`` with ``b = n * bound1 + 1``. Since
    ``|k| = |<q1, pi>| <= n * bound1`` for any 0/1 trace, ``b`` is strictly
    larger than ``|k|``. For an honest proof,
    ``<q, pi> = k + b * (k**2 - val)``, which is what ``answers`` enumerates.

    Returns ``(q, (b, val))``.
    """
    n = trace_len(circuit)
    b = n * bound1 + 1
    q1, q2 = tensor_queries(circuit, bound1)
    q3, val = constraint_query(circuit, bound2)
    q = [q1[i] + b * (q2[i] - q3[i]) for i in range(proof_len(circuit))]
    st = (b, val)
    return (q, st)

def answers(st: State, bound: int | None = None) -> Iterator[int]:
    """Enumerate the values ``<q, pi>`` an honest proof can produce.

    For ``k = 0, 1, 2, ...`` below ``bound`` (default ``bound = b``), yield
    ``k + b * (k * k - val)``. Zero is yielded once; every other magnitude is
    yielded as ``+k`` and then ``-k``. Smaller ``|k|`` come first, so a prefix
    of this iterator covers the smallest ``|k|``.
    """
    b, val = st
    if bound is None:
        bound = b
    for k in range(bound):
        rhs = k * k - val
        if k == 0:
            yield b * rhs
        else:
            yield k + b * rhs
            yield -k + b * rhs

Circuits 先不管,簡言之就是透過 gates 和 wires,把想要的運算改寫成「變數」與「變數乘積」的線性組合

首先,對於 bits $b_1, b_2$ 而言,xor 可以寫成 $b_1 + b_2 - 2b_1b_2$;同理,其他位元運算(例如 AND 就是 $b_1b_2$)也都可以寫成兩個變數與它們乘積的線性組合

所以事實上今天可以給定一個向量:
$\pi = (x_1, x_2, \ldots, x_n,\ x_1x_1,\ x_2x_1,\ x_2x_2,\ x_3x_1,\ \ldots,\ x_nx_n)$,也就是 trace 後面接上所有 $j \le i$ 的乘積 $x_ix_j$

然後透過事先構造好的向量跟$\pi$內積(dot)來完成每個 bit 組合後是否滿足特定運算的結果的檢查

為此,在 dpp 裡面作了以下定義:

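  • $q1$:前 n 項(對應 trace)是範圍在 $[-B_1, B_1]$($B_1$ = BOUND1 = $2^8$)的隨機數 $v_i$,其餘為 0,所以 $\langle q1, \pi \rangle = k = \sum v_i x_i$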
  • $q2$:前 n 項為 0,後半段與 $\pi$ 的乘積項一一對應:位置 $(i, j)$ 放 $v_iv_j$($i = j$)或 $2v_iv_j$($i \neq j$,因為 $(\sum v_ix_i)^2$ 展開後交叉項會出現兩次),因此對誠實的 $\pi$ 有 $\langle q2, \pi \rangle = k^2$
  • $q3$:把所有約束(輸入必須是 0/1、每個 gate 的運算、輸出必須為 0)用範圍在 $[-B_2, B_2]$($B_2 = 2^{40}$)的隨機係數做線性組合。只要 $\pi$ 滿足全部約束(在這題就是加法正確),$\langle q3, \pi \rangle$ 永遠等於同一個固定值(我們叫它 val)

最後
$q = q1 + B \cdot (q2 - q3)$

其中 $B$ 是一個大數(程式中 $b = n \cdot B_1 + 1$,保證大於 $|k|$)。當一個滿足約束的 proof $\pi$ 輸入進來時,我們可以有效預測 $\langle q, \pi \rangle$ 的值:它對 $q3$ 的內積永遠固定是 val,對 $q2$ 的內積是 $k^2$,所以 $\langle q, \pi \rangle = k + B \cdot (k^2 - val)$,其中 $k = \langle q1, \pi \rangle$,唯一不固定的就只剩下 $k$。

於是就可以建立答案池,根據 k 和固定的 val 值去推算所有結果。
這樣做的方法是透過前面加入一個 k 以及一個大 Bound 讓攻擊者如果想透過線性補償滿足 $(q2 - q3)$ 部分的計算依然需要考慮 q1 導致無解。

snarg.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import hashlib
import math
from itertools import islice
from multiprocessing import Pool
from typing import BinaryIO
from fastecdsa.curve import P256
from fastecdsa.keys import gen_private_key
from fastecdsa.point import Point
from fastecdsa.encoding.sec1 import SEC1Encoder
from tqdm import tqdm

from circuit import Circuit
import dpp


Proof = tuple[Point, Point]  # (h1, h2) as produced by prove()
State = tuple[int, list[int], set[bytes]]  # (sk, input part of q, encoded answer table) from vk_state()

BOUND1 = 2**8  # passed to dpp.sample as bound1 (tensor query coefficients)
BOUND2 = 2**40  # passed to dpp.sample as bound2 (random constraint weights)

def hash_to_point(i: int) -> Point:
    """Deterministically map index ``i`` to a P-256 point (try-and-increment).

    For ``ctr = 0..255``, hash ``i`` (8 bytes, little-endian) || ``ctr`` with
    SHA-256 and reduce it to an x-coordinate. Return the first point whose
    ``x^3 - 3x + b`` is a square mod p. The square root uses the exponent
    ``(p + 1) / 4``, which is valid because P-256's p is 3 mod 4.

    Raises ValueError if all 256 counters fail.
    """
    p = P256.p
    for ctr in range(256):
        # int.from_bytes defaults to big-endian only on Python 3.11+.
        x = int.from_bytes(hashlib.sha256(i.to_bytes(8, 'little') + bytes([ctr])).digest()) % p
        y_sq = (pow(x, 3, p) - 3 * x + P256.b) % p
        y = pow(y_sq, (p + 1) // 4, p)
        if pow(y, 2, p) == y_sq:
            return Point(x, y, curve=P256)
    raise ValueError(f'hash_to_point failed for i={i}')

def compute_c(args: tuple[int, int, int]) -> bytes:
    """Return the compressed SEC1 encoding of ``q_i * G + sk * H(i)``.

    Takes a single ``(i, q_i, sk)`` tuple so it can be mapped with
    ``Pool.imap`` in ``setup``.
    """
    i, qi, sk = args
    c = qi * P256.G + sk * hash_to_point(i)
    return SEC1Encoder.encode_public_key(c, compressed=True)

def setup(circuit: Circuit, crs: BinaryIO, vk: BinaryIO) -> None:
    """Generate the CRS and the verification key for ``circuit``.

    vk layout:

    * ``sk`` on the first line;
    * the input part ``q[:n]`` as one comma-separated line;
    * 33-byte compressed points ``a * G`` for the first ``table_size`` values
      of ``dpp.answers``.

    crs layout: one 33-byte compressed point per entry of ``q[n:]``, namely
    ``q[n + i] * G + sk * H(i)``.
    """
    n = len(circuit.inputs)
    q, st = dpp.sample(circuit, BOUND1, BOUND2)
    sk = gen_private_key(P256)
    vk.write(f'{sk}\n'.encode())
    vk.write((','.join(str(qi) for qi in q[:n]) + '\n').encode())

    with Pool() as pool:
        args = ((i, qi, sk) for i, qi in enumerate(q[n:]))
        for c_enc in tqdm(pool.imap(compute_c, args), total=len(q) - n):
            crs.write(c_enc)

    # allow for small completeness error
    # NOTE(review): presumably sized from the spread of an honest
    # k = sum(v_i * x_i), i.e. about sqrt(n / 6) * BOUND1 when roughly half of
    # the trace bits are 1. Only answers with |k| up to about 5x that are
    # stored, so rare honest proofs can be rejected.
    table_size = 10 * round(math.sqrt(dpp.trace_len(circuit) / 6) * BOUND1)
    for a in tqdm(islice(dpp.answers(st), table_size), total=table_size):
        p = a * P256.G
        p_enc = SEC1Encoder.encode_public_key(p, compressed=True)
        vk.write(p_enc)

def prove(circuit: Circuit, inputs: list[int], pk: BinaryIO) -> Proof:
    """Compute the SNARG proof ``(h1, h2)`` for ``inputs``.

    With ``pi = dpp.prove(circuit, inputs)`` and ``n`` circuit inputs:

    * ``h1 = sum(pi[n + i] * H(i))``
    * ``h2 = sum(pi[n + i] * C_i)``

    ``C_i`` is the i-th 33-byte point read sequentially from ``pk`` (the CRS).
    """
    dpp_proof = dpp.prove(circuit, inputs)
    n = len(circuit.inputs)
    h1 = Point._identity_element()
    h2 = Point._identity_element()
    for i, t in enumerate(tqdm(dpp_proof[n:])):
        # Always consume this entry's CRS point so the stream stays aligned,
        # even when the term is skipped.
        c_enc = pk.read(33)
        if t == 0:
            continue
        c = SEC1Encoder.decode_public_key(c_enc, P256)
        h1 += t * hash_to_point(i)
        h2 += t * c
    proof = (h1, h2)
    return proof

def vk_state(vk: BinaryIO) -> State:
    """Parse a verification key in the layout written by ``setup``.

    Returns ``(sk, q_inputs, table)``. ``table`` is the set of 33-byte chunks
    read after the first two lines.
    """
    sk = int(vk.readline())
    q_inputs = [int(x) for x in vk.readline().split(b',')]
    table = set()
    while p_enc := vk.read(33):
        table.add(p_enc)
    return (sk, q_inputs, table)

def verify(inputs: list[int], st: State, proof: Proof) -> bool:
    """Check a SNARG proof against the verifier state.

    Recovers ``P = h2 - sk * h1 + (sum_i q_i * inputs_i) * G`` and accepts
    iff the compressed encoding of ``P`` is in the answer table. For an honest
    proof, ``P = <q, pi> * G``.

    Asserts that ``inputs`` has one 0/1 entry per stored input coefficient.
    """
    sk, q_inputs, table = st
    assert len(inputs) == len(q_inputs)
    assert all(x in (0, 1) for x in inputs)
    h1, h2 = proof
    # The sk * H(i) blinding cancels, leaving sum(pi_i * q_i) * G.
    p = h2 - sk * h1
    input_sum = sum(q_inputs[i] * inputs[i] for i in range(len(inputs)))
    p += input_sum * P256.G
    p_enc = SEC1Encoder.encode_public_key(p, compressed=True)
    return p_enc in table

在 snarg.py 中定義了整個 SNARG 協議,我們可以特別注意到是基於橢圓曲線的,一開始有個生成元 $G$ 以及一個 HASH 函數把 int 值 HASH 後對上曲線上的點,定義它叫做 $H(i)$。

公開出來的公鑰(crs.bin 檔案內容)包含這些點:
$C_i = q_i \cdot G + sk \cdot H_i$

驗證者的私密 vk(vk.bin)包含 $sk$、$q$ 的輸入部分 $q_0, \ldots, q_{n-1}$ 以及答案表;$q$ 的其餘部分只以 $C_i$ 的形式藏在 crs 裡,不會直接公開

Prover 每次要提供的就是兩個點:
$h_1 = \sum_{i=n}^{M-1} \pi_i \cdot H_i$
$h_2 = \sum_{i=n}^{M-1} \pi_i \cdot C_i$
M 就是剛剛那個 $\pi$ 的長度(實作上 crs 和 hash 的 index 是從 0 開始算,對應的是 $\pi$ 的第 $n$ 項以後)

根據前面講 dpp 的方法他這樣就可以還原點的結果並跟答案池對答案:
$P = h_2 - sk \cdot h_1 + \left(\sum_{i<n} q_i \pi_i\right) \cdot G = \langle q, \pi \rangle \cdot G$

最後是這題的 server.py (dist file 並未提供私鑰 vk.bin)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/local/bin/python3
import secrets
from fastecdsa.curve import P256
from fastecdsa.encoding.sec1 import SEC1Encoder

import snarg
from add import int_to_bits

if __name__ == '__main__':
    # Challenge loop: submit valid proofs for wrong sums 20 times in a row.
    n = 64

    with open('vk.bin', 'rb') as f:
        st = snarg.vk_state(f)

    streak = 0
    while True:
        a = secrets.randbits(n)
        b = secrets.randbits(n)
        print(f'what is {a} + {b}? (mod 2^64)')

        # Keep asking about the same a, b until a valid proof for a wrong
        # answer is submitted.
        while True:
            c = int(input('answer: '))
            assert 0 <= c < (1 << 64)
            correct = c == (a + b) % (1 << n)

            proof_buf = bytes.fromhex(input('proof: '))
            assert len(proof_buf) == 2 * 33  # two compressed points (h1, h2)
            h1 = SEC1Encoder.decode_public_key(proof_buf[:33], P256)
            h2 = SEC1Encoder.decode_public_key(proof_buf[33:], P256)

            # Public inputs: the bits of a, then b, then c.
            inputs = int_to_bits(a, n) + int_to_bits(b, n) + int_to_bits(c, n)
            proof = (h1, h2)
            valid = snarg.verify(inputs, st, proof)

            if valid and correct:
                print('correct! but that was obvious...')
                streak = 0
            elif valid and not correct:
                print('huh?')
                streak += 1
                if streak >= 20:
                    print(open('flag.txt').read().strip())
                    exit()
                break
            else:
                streak = 0
                print('wrong...')

最後,這題的漏洞是對於 BOUND1 (snarg.py 中)的定義過小
code trace 注意到 dpp.py 這一段

1
2
3
4
5
6
7
8
9
10
def tensor_queries(circuit: Circuit, bound: int) -> tuple[Vector, Vector]:
    """Build the tensor queries ``(q1, q2)`` that check the product entries.

    A secret vector ``v`` with entries in ``[-bound, bound]`` is drawn for the
    ``n`` trace positions.

    * ``q1`` carries ``v`` on the trace part, so
      ``<q1, pi> = k = sum(v[i] * x_i)``.
    * ``q2`` carries ``v[i] * v[i]`` on the diagonal product entries and
      ``2 * v[i] * v[j]`` on the off-diagonal ones, so an honest proof gives
      ``<q2, pi> = k ** 2``.
    """
    n = trace_len(circuit)
    # The weakness: with bound = BOUND1 = 2**8, each v[i] has only 513
    # possible values, small enough to brute-force one or two coefficients.
    v = [random.randint(-bound, bound) for _ in range(n)]
    q1 = [0] * proof_len(circuit)
    q2 = [0] * proof_len(circuit)
    for i in range(n):
        q1[i] = v[i]
        for j in range(i + 1):
            q2[pair_index(circuit, i, j)] = v[i] * v[j] if i == j else 2 * v[i] * v[j]
    return (q1, q2)

這邊 q2 的構造在 BOUND1 = 2**8 的範圍下有問題。假設攻擊者只翻轉 c 的 LSB 讓值偏離一下(這樣 a+b 就不等於 c),受影響的 tensor 係數只有 c 的 LSB(trace index 128)的 $v_{128}$,以及和它相乘的 wire(index 192,也就是第一個 gate 的輸出)的 $v_l$。兩者都落在小小的 $[-2^8, 2^8]$ 範圍內,總共只有 $513^2$ 種組合,所以分批枚舉就好。

而這剛好是可以枚舉出來的。注意到 $h_1$ 的部分因為 hash 是公開的,所以很好偽造。
$h_2$ 的部分,因為只偏移了 LSB,爆破兩個參數時,只要在最後加上「本次猜測造成的值和原本正確答案相差幾倍的點 $G$」把它補回去,送出去試試看就知道猜得對不對了。

其他我覺得看 exploit 更快(?) 一樣是用 Gemini 修好的
exp.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
from pwn import *
import snarg
import dpp
from add import build_adder, int_to_bits
from fastecdsa.curve import P256
from fastecdsa.encoding.sec1 import SEC1Encoder

# Replace snarg's tqdm with an identity wrapper so snarg.prove()/setup()
# print no progress bars.
snarg.tqdm = lambda x, **kwargs: x

def _read_question(conn):
    """Read one ``what is {a} + {b}? (mod 2^64)`` prompt and return ``(a, b)``."""
    conn.recvuntil(b'what is ')
    line = conn.recvline().decode().strip()
    parts = line.split(' ')
    # parts == [a, '+', 'b?', '(mod', '2^64)']; drop the trailing '?'.
    return int(parts[0]), int(parts[2][:-1])


def main():
    """Forge 20 consecutive accepted proofs with the LSB of ``c`` flipped.

    Phase 1 brute-forces the two small tensor coefficients touched by the
    flip (``v`` at the LSB of ``c`` and at the first gate wire). It streams
    every candidate proof for one question over a single connection.
    Phase 2 reuses the recovered pair to forge 20 proofs in a row on a fresh
    connection.
    """
    import sys

    print("[*] Building the circuit...")
    n = 64
    circuit = build_adder(n)
    num_inputs = len(circuit.inputs)  # 3 * n: the bits of a, b and c

    idx_c0 = 2 * n  # trace index of the LSB of c (inputs are a || b || c)
    l_idx = num_inputs  # trace index of the first gate's output wire

    pair_c0 = dpp.pair_index(circuit, idx_c0, idx_c0)
    pair_l = dpp.pair_index(circuit, l_idx, idx_c0)

    print("[*] Extracting commitment points from crs.bin...")
    with open('crs.bin', 'rb') as f:
        # CRS entry i commits to proof position num_inputs + i.
        f.seek((pair_c0 - num_inputs) * 33)
        C_pair_c0 = SEC1Encoder.decode_public_key(f.read(33), P256)

        f.seek((pair_l - num_inputs) * 33)
        C_pair_l = SEC1Encoder.decode_public_key(f.read(33), P256)

    H_pair_c0 = snarg.hash_to_point(pair_c0 - num_inputs)
    H_pair_l = snarg.hash_to_point(pair_l - num_inputs)

    q_mod = P256.q
    half = pow(2, -1, q_mod)
    bound = snarg.BOUND1
    b_val = dpp.trace_len(circuit) * bound + 1  # same b as dpp.sample
    n_candidates = 2 * bound + 1  # number of v values in [-bound, bound]

    def forge_base(a, b):
        """Return ``(c_wrong, delta, h1_prime, h2_base)`` for the question a + b."""
        c_correct = (a + b) % (1 << n)
        c_wrong = c_correct ^ 1
        delta = 1 if c_wrong > c_correct else -1

        inputs = int_to_bits(a, n) + int_to_bits(b, n) + int_to_bits(c_correct, n)
        with open('crs.bin', 'rb') as pk:
            h1, h2 = snarg.prove(circuit, inputs, pk)

        # Shift the (c0, c0) product entry by delta and the (l, c0) entry by
        # delta / 2. Off-diagonal tensor entries carry a factor of 2.
        h1_prime = h1 + (delta % q_mod) * H_pair_c0 + ((delta * half) % q_mod) * H_pair_l
        h2_base = h2 + (delta % q_mod) * C_pair_c0 + ((delta * half) % q_mod) * C_pair_l
        return c_wrong, delta, h1_prime, h2_base

    print("[*] Connecting to target...")
    r = remote('dot.chals.dicec.tf', 1337)
    # r = process(['python3', 'server.py'])

    print("\n=== Phase 1: High-Speed Pipeline Brute-Force ===")
    a, b = _read_question(r)

    print("[*] Computing honest proof to serve as base...")
    c_wrong, delta, h1_prime, h2_base = forge_base(a, b)
    h1_prime_hex = SEC1Encoder.encode_public_key(h1_prime, True).hex()

    v128_correct, vl_correct = None, None

    for v128 in range(-bound, bound + 1):
        sys.stdout.write(f"\r[*] Testing chunk v128 = {v128:4d} (Sending {n_candidates} payloads...)")
        sys.stdout.flush()

        # The G-offset to cancel for candidate vl is K_base + vl * K_step.
        K_base = (delta * (v128 + b_val * v128**2)) % q_mod
        K_step = (delta * b_val * v128) % q_mod
        K_start = (K_base - bound * K_step) % q_mod  # offset at vl = -bound

        H_current = h2_base + (-K_start % q_mod) * P256.G
        H_step = (-K_step % q_mod) * P256.G

        payload = b""
        for _ in range(-bound, bound + 1):
            h2_prime_hex = SEC1Encoder.encode_public_key(H_current, True).hex()
            payload += f"{c_wrong}\n{h1_prime_hex}{h2_prime_hex}\n".encode()
            H_current = H_current + H_step

            if len(payload) > 16384:
                r.send(payload)
                payload = b""

        if payload:
            r.send(payload)

        # Accumulate this batch's replies so a "wrong..." or "huh?" split
        # across two recv() calls is still counted exactly once.
        buf = b""
        while buf.count(b'wrong') < n_candidates:
            data = r.recv(8192, timeout=3.0)
            if not data:
                break
            buf += data
            if b'huh?' in buf:
                wrongs_before = buf.split(b'huh?', 1)[0].count(b'wrong')
                vl_correct = -bound + wrongs_before
                v128_correct = v128
                print(f"\n[+] BINGO! Found server secrets -> v128: {v128_correct}, vl: {vl_correct}")
                break

        if vl_correct is not None:
            break

    r.close()

    if vl_correct is None:
        print("\n[-] Brute force exhausted without an accepted proof; aborting.")
        return

    print("\n=== Phase 2: Rapid Fire 20 Streak ===")
    print("[*] Reconnecting for a clean streak...")
    r = remote('dot.chals.dicec.tf', 1337)
    # r = process(['python3', 'server.py'])

    for streak in range(20):
        a, b = _read_question(r)
        c_wrong, delta, h1_prime, h2_base = forge_base(a, b)

        K = (delta * (v128_correct + b_val * v128_correct**2 + b_val * vl_correct * v128_correct)) % q_mod
        h2_prime = h2_base + (-K % q_mod) * P256.G

        h1_hex = SEC1Encoder.encode_public_key(h1_prime, True).hex()
        h2_hex = SEC1Encoder.encode_public_key(h2_prime, True).hex()

        r.sendlineafter(b'answer: ', str(c_wrong).encode())
        r.sendlineafter(b'proof: ', f"{h1_hex}{h2_hex}".encode())

        res = r.recvline().decode().strip()
        print(f"[*] Streak {streak + 1}/20 => {res}")
        if res != 'huh?':
            # A rejected proof resets the streak and the server re-asks the
            # same question, so waiting for a new one would hang.
            print("[-] Forgery rejected; aborting.")
            break
        if streak == 19:
            print("\n[+] FLAG FOUND:")
            print(r.recvall(timeout=1).decode())
            break

    r.close()

if __name__ == '__main__':
    main()