-
Notifications
You must be signed in to change notification settings - Fork 10
/
build-stager-encrypt.py
168 lines (141 loc) · 4.54 KB
/
build-stager-encrypt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) Kuba Szczodrzyński 2022-09-23.
import sys
from subprocess import PIPE, Popen
from ltchiptool.soc.bk72xx.util import BekenBinary
from ltchiptool.util import CRC16
from ltchiptool.util.intbin import inttobe16
# Command line: <mode> <input> <output>
if len(sys.argv) != 4:
    print(f"usage: {sys.argv[0]} <mode> <input> <output>")
    # sys.exit, not the site-injected exit(): the latter is meant for the
    # interactive shell and is absent under `python -S` / frozen builds.
    sys.exit(2)
mode = sys.argv[1]
# Beken helper used below as bk.crypt(addr, data) to encrypt each code chunk
# for its target address. NOTE(review): the hex string is presumably the
# chip's encryption key/coefficients — confirm against ltchiptool.
bk = BekenBinary("510fb093a3cbeadc5993a17ec7adeb03")
# 8-byte repeating key applied by xor_encode() to blocks marked xor=True.
xor_key = b"\x55\xaa\x5a\x55\xa5\x4d\x63\x7f"
def find_null(data: bytes, source: bytes) -> None:
    """Abort the build if *data* contains a 0x00 byte.

    The script calls this on every intermediate stage and on the final
    output, so the emitted payload never contains a null byte. On the first
    null found, print its offset, dump *source* as hex with a caret under the
    offending byte, and exit with status 1.

    :param data: bytes to scan for 0x00.
    :param source: bytes to display. The caret position assumes *source* is
        aligned byte-for-byte with *data* (e.g. plaintext vs. its encrypted
        form of the same length).
    """
    i = data.find(b"\x00")
    if i < 0:
        return
    print(f"Found null byte at {hex(i)}: ")
    print(source.hex(" ", -1))
    # Each byte occupies 3 columns ("xx "), so the caret lands under the
    # second hex digit of byte i.
    print(" " * ((i * 3) + 1) + "^")
    sys.exit(1)
def cmd(*args: str) -> str:
    """Run a command and return its stripped standard output.

    :param args: program followed by its arguments; executed directly,
        without a shell.
    :return: captured stdout, decoded with the default (UTF-8) codec and
        stripped of surrounding whitespace.
    """
    # communicate() drains stdout while waiting for the child. Calling wait()
    # before reading deadlocks once the output fills the OS pipe buffer.
    with Popen(args, stdout=PIPE) as p:
        out, _ = p.communicate()
    return out.decode().strip()
def xor_encode(var: bytes, offs: int, key=None) -> bytes:
    """XOR *var* with a repeating key, as if *var* started at offset *offs*.

    The key is rotated left by ``offs % len(key)``, so ``var[j]`` is XORed
    with ``key[(offs + j) % len(key)]``. This lets a fragment be encoded as
    though it sat at position *offs* of a stream XORed from offset 0. XOR is
    its own inverse, so encoding twice with the same *offs* restores the
    input.

    :param var: bytes to encode.
    :param offs: offset of ``var[0]`` within the XOR stream.
    :param key: XOR key; defaults to the module-level ``xor_key``.
    :return: encoded bytes, same length as *var*.
    :raises ValueError: if the key is empty.
    """
    if key is None:
        key = xor_key
    if not key:
        raise ValueError("XOR key must not be empty")
    # Equivalent to the former ((offs - 1) % 8) + 1 rotation, where a
    # rotation by 8 is the identity, but not tied to an 8-byte key.
    shift = offs % len(key)
    rotated = key[shift:] + key[:shift]
    period = len(rotated)
    return bytes(b ^ rotated[i % period] for i, b in enumerate(var))
# Read the raw stager code and pad it to the fixed payload size: the block
# table splits it into code[0:32] and code[32:56].
MAX_CODE_LEN = 56

with open(sys.argv[2], "rb") as f:
    code = f.read()
if len(code) > MAX_CODE_LEN:
    print(f"Code too long! {len(code)} bytes > {MAX_CODE_LEN} bytes")
    sys.exit(1)
print(f"Code length: {len(code)} bytes")
print(code.hex(" ", -1))
# Fill the unused tail with 0xAA bytes.
code = code.ljust(MAX_CODE_LEN, b"\xaa")
# Per-mode layout of the padded code. Keys consumed by the build loop:
#   addr      - address passed to bk.crypt() when encrypting this chunk
#   code      - plaintext chunk to encrypt
#   pre_block - raw bytes emitted before the block in the output
#   post_code - bytes appended after the block for the CRC calculation only
#               (never written to the output)
#   revcrc    - value XORed with the block's CRC-16/CMS to produce the 2-byte
#               "fix" appended after the code (derived offline with reveng,
#               see the per-block notes)
#   xor, xor_offs - XOR-encode the emitted bytes via xor_encode() at xor_offs
blocks: list[dict] = []
if mode.endswith("-standard"):
    # Standard (no XOR)
    addr = 0x1C5AC0 if mode.startswith("bk7231t") else 0x1B5AC0
    blocks += [
        dict(
            addr=addr + 0x00,
            code=code[0:32],  # 32 bytes
            pre_block=b"11",
        ),
        dict(
            addr=addr + 0x20,
            code=code[32:56],  # 24 bytes
            post_code=b"\x00\x00\x00\xff\xff\xff",
            # reveng -m crc-16/cms -i 0 -v 000000ffffffffff
            # (passwd. last 2 bytes + null term. + rest of block + expected CRC)
            revcrc=0x7ADF,
        ),
    ]
elif mode == "bk7231n-ip":
    # BK7231N
    addr = 0x1B5AC0
    blocks += [
        dict(
            addr=addr + 0x00,
            code=code[0:32],  # 32 bytes
            pre_block=b"11",
        ),
        dict(
            addr=addr + 0x20,
            code=code[32:56],  # 24 bytes
            post_code=b"\x00\x00\x00\x00\x31\x30",  # beginning of IP (10.0.0.x)
            # reveng -m crc-16/cms -i 0 -v 0000000031302e30
            # (passwd. last 2 bytes + null term. + rest of block + expected CRC)
            revcrc=0x33B3,
        ),
    ]
elif mode == "bk7231n-xor":
    # BK7231N
    addr = 0x1B5AC0
    blocks += [
        dict(
            addr=addr + 0x00,
            code=code[0:32],  # 32 bytes
            pre_block=b"11",
            xor=True,
            xor_offs=0x6C,
        ),
        dict(
            addr=addr + 0x20,
            code=code[32:56],  # 24 bytes
            post_code=b"\x55\xaa\x5a\xff\xff\xff",
            # reveng -m crc-16/cms -i 0 -v 55aa5affffffffff
            # (passwd. last 2 bytes + null term. + rest of block + expected CRC) - XOR-encoded
            revcrc=0x67EC,
            xor=True,
            xor_offs=0x8E,
        ),
    ]
else:
    print(f"Unknown mode: {mode}")
    # sys.exit, not the interactive-only site exit().
    sys.exit(1)
# Encrypt every block, spoof its CRC where requested, XOR-encode if needed,
# and assemble the final payload, then write it to the output file.
data = b""
for i, block in enumerate(blocks):
    xor = block.get("xor", False)
    # encrypt code and find null bytes
    # (a null is reported against the plaintext chunk at the same offset)
    block["crypt"] = b"".join(bk.crypt(block["addr"], block["code"]))
    find_null(block["crypt"], block["code"])
    # add pre-block padding (emitted as-is: not encrypted, not XOR-encoded)
    data += block.get("pre_block", b"")
    # add pre-code padding
    block["crypt"] = block.get("pre_code", b"") + block["crypt"]
    # calculate "fix" to spoof CRC
    if "revcrc" in block:
        # 2 bytes = CRC of the data so far XOR the precomputed revcrc.
        # NOTE(review): presumably chosen so the CRC over crypt + fix +
        # post_code hits the value the device expects — see the reveng notes.
        crc = CRC16.CMS.calc(block["crypt"])
        fix = inttobe16(crc ^ block["revcrc"])
        print(f"Block {i} fix-CRC: {fix.hex()}")
        block["crypt"] += fix
        find_null(fix, fix)
    # store entire block
    block_strip = block["crypt"]
    if xor:
        block_strip = xor_encode(block_strip, block["xor_offs"])
    data += block_strip
    # add post-code padding: used only for the CRC below, never added to data
    block["crypt"] = block["crypt"] + block.get("post_code", b"")
    # add post-block padding
    data += block.get("post_block", b"")
    # calculate actual CRC (over the un-XORed crypt + post_code)
    block["crc"] = inttobe16(CRC16.CMS.calc(block["crypt"]))
    print(f"Block {i} CRC: {block['crc'].hex()}")
    find_null(block["crc"], block["crc"])
    # store block CRC for full blocks
    # only a full 32-byte block gets its CRC emitted; a shorter block does not
    if len(block_strip) == 32:
        if xor:
            # the CRC directly follows the 32-byte block, hence offset + 0x20
            block["crc"] = xor_encode(block["crc"], block["xor_offs"] + 0x20)
            print(f"Block {i} CRC (XOR): {block['crc'].hex()}")
        data += block["crc"]
print(f"Output ({len(data)} bytes):")
print(data.hex(" ", -1))
# Final check on the assembled output; this is the only check that sees the
# XOR-encoded bytes (the per-step checks above run before XOR encoding).
find_null(data, data)
with open(sys.argv[3], "wb") as f:
    f.write(data)