Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Sighash Flag #2807

Merged
merged 25 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruff-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ jobs:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
version: 0.7.4
version: 0.8.2
2 changes: 1 addition & 1 deletion .github/workflows/ruff-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ jobs:
- uses: chartboost/ruff-action@v1
with:
args: "format --check"
version: 0.7.4
version: 0.8.2
4,634 changes: 2,324 additions & 2,310 deletions apiary.apib

Large diffs are not rendered by default.

45 changes: 23 additions & 22 deletions counterparty-core/counterpartycore/lib/api/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@
database.update_version(state_db)


def run_api_server(args, server_ready_value, stop_event):
def run_api_server(args, server_ready_value, stop_event, parent_pid):
logger.info("Starting API Server process...")

def handle_interrupt_signal(signum, frame):
Expand Down Expand Up @@ -513,7 +513,7 @@
wsgi_server = wsgi.WSGIApplication(app, args=args)

logger.info("Starting Parent Process Checker thread...")
parent_checker = ParentProcessChecker(wsgi_server)
parent_checker = ParentProcessChecker(wsgi_server, stop_event, parent_pid)
parent_checker.start()

app.app_context().push()
Expand All @@ -539,52 +539,53 @@
watcher.stop()
watcher.join()

if parent_checker is not None:
logger.trace("Stopping Parent Process Checker thread...")
parent_checker.stop()
parent_checker.join()

logger.info("API Server stopped.")


def is_process_alive(pid):
"""Check For the existence of a unix pid."""
try:

Check warning

Code scanning / pylint

Unnecessary "else" after "return", remove the "else" and de-indent the code inside it. Warning

Unnecessary "else" after "return", remove the "else" and de-indent the code inside it.
os.kill(pid, 0)
except OSError:
return False
else:
return True


# This thread is used for the following two reasons:
# 1. `docker-compose stop` does not send a SIGTERM to the child processes (in this case the API v2 process)
# 2. `process.terminate()` does not trigger a `KeyboardInterrupt` or execute the `finally` block.
class ParentProcessChecker(threading.Thread):
def __init__(self, wsgi_server):
def __init__(self, wsgi_server, stop_event, parent_pid):
super().__init__(name="ParentProcessChecker")
self.daemon = True
self.wsgi_server = wsgi_server
self.stop_event = threading.Event()
self.stop_event = stop_event
self.parent_pid = parent_pid

def run(self):
parent_pid = os.getppid()
try:
while not self.stop_event.is_set():
if os.getppid() != parent_pid:
logger.debug("Parent process is dead. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
break
while not self.stop_event.is_set() and is_process_alive(self.parent_pid):
time.sleep(1)
logger.debug("Parent process stopped. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
except KeyboardInterrupt:
pass

def stop(self):
self.stop_event.set()


class APIServer(object):
def __init__(self):
def __init__(self, stop_event):
self.process = None
self.server_ready_value = Value("I", 0)
self.stop_event = multiprocessing.Event()
self.stop_event = stop_event

def start(self, args):
if self.process is not None:
raise Exception("API Server is already running")
self.process = Process(
target=run_api_server, args=(vars(args), self.server_ready_value, self.stop_event)
target=run_api_server,
args=(vars(args), self.server_ready_value, self.stop_event, os.getpid()),
)
self.process.start()
return self.process
Expand Down
1 change: 1 addition & 0 deletions counterparty-core/counterpartycore/lib/api/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def get_tx_info(tx_hex, block_index=None):
db,
deserialize.deserialize_tx(tx_hex, use_txid=use_txid),
block_index=block_index,
composing=True,
)
)
return (
Expand Down
14 changes: 10 additions & 4 deletions counterparty-core/counterpartycore/lib/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,21 @@ def read_transaction(vds, use_txid=True):
offset_before_tx_witnesses = vds.read_cursor - start_pos
for vin in transaction["vin"]: # noqa: B007
witnesses_count = vds.read_compact_size()
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
transaction["vtxinwit"].append(witness)
if witnesses_count == 0:
transaction["vtxinwit"].append([])
else:
vin_witnesses = []
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
vin_witnesses.append(witness)
transaction["vtxinwit"].append(vin_witnesses)

transaction["lock_time"] = vds.read_uint32()
data = vds.input[start_pos : vds.read_cursor]

transaction["tx_hash"] = ib2h(double_hash(data))
transaction["tx_id"] = transaction["tx_hash"]
if transaction["segwit"]:
hash_data = data[:4] + data[6:offset_before_tx_witnesses] + data[-4:]
transaction["tx_id"] = ib2h(double_hash(hash_data))
Expand Down
157 changes: 152 additions & 5 deletions counterparty-core/counterpartycore/lib/gettxinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import binascii
import logging
import struct
from io import BytesIO

from counterpartycore.lib import arc4, backend, config, ledger, message_type, script, util
from counterpartycore.lib.exceptions import BTCOnlyError, DecodeError
Expand Down Expand Up @@ -146,7 +147,8 @@
if "value" in vin:
return vin["value"], vin["script_pub_key"], vin["is_segwit"]

# Note: We don't know what block the `vin` is in, and the block might have been from a while ago, so this call may not hit the cache.
# Note: We don't know what block the `vin` is in, and the block might
# have been from a while ago, so this call may not hit the cache.
vin_ctx = backend.bitcoind.get_decoded_transaction(vin["hash"])

is_segwit = len(vin_ctx["vtxinwit"]) > 0
Expand All @@ -155,11 +157,152 @@
return vout["value"], vout["script_pub_key"], is_segwit


def is_valid_der(der):

Check warning

Code scanning / pylint

Too many return statements (8/6). Warning

Too many return statements (8/6).
if not isinstance(der, bytes):
return False
try:
s = BytesIO(der)
compound = s.read(1)[0]
if compound != 0x30:
return False
length = s.read(1)[0]
if length + 2 != len(der):
return False
marker = s.read(1)[0]
if marker != 0x02:
return False
rlength = s.read(1)[0]
_r = int(s.read(rlength).hex(), 16)
marker = s.read(1)[0]
if marker != 0x02:
return False
slength = s.read(1)[0]
s = int(s.read(slength).hex(), 16)
if len(der) != 6 + rlength + slength:
return False
return True
except Exception:

Check warning

Code scanning / pylint

Catching too general exception Exception. Warning

Catching too general exception Exception.
return False


def is_valid_schnorr(schnorr):
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

if not isinstance(schnorr, bytes):
return False
if len(schnorr) not in [64, 65]:
return False
if len(schnorr) == 65:
schnorr = schnorr[:-1]
try:
r = int.from_bytes(schnorr[0:32], byteorder="big")
s = int.from_bytes(schnorr[32:64], byteorder="big")
except Exception:

Check warning

Code scanning / pylint

Catching too general exception Exception. Warning

Catching too general exception Exception.
return False
if (r >= p) or (s >= n):
return False
return True


def get_der_signature_sighash_flag(value):
if is_valid_der(value[:-1]):
return value[-1:]
return None


def get_schnorr_signature_sighash_flag(value):

Check warning

Code scanning / pylint

Either all return statements in a function should return an expression, or none of them should. Warning

Either all return statements in a function should return an expression, or none of them should.
if is_valid_schnorr(value):
if len(value) == 65:
return value[-1:]
return b"\x01" # SIGHASH_ALL


def collect_sighash_flags(script_sig, witnesses):
flags = []

# P2PK, P2PKH, P2MS
if script_sig != b"":
asm = script.script_to_asm(script_sig)
for item in asm:
flag = get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)

if len(witnesses) == 0:
return flags

witnesses = [
binascii.unhexlify(witness) if isinstance(witness, str) else witness
for witness in witnesses
]

# P2WPKH
if len(witnesses) == 2:
flag = get_der_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# P2TR key path spend
if len(witnesses) == 1:
flag = get_schnorr_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# Other cases
if len(witnesses) >= 3:
for item in witnesses:
flag = get_schnorr_signature_sighash_flag(item) or get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)
return flags

return flags


# class SighashFlagError(DecodeError):
class SighashFlagError(Exception):
pass


# known transactions with invalid SIGHASH flag
SIGHASH_FLAG_TRANSACTION_WHITELIST = [
"c8091f1ef768a2f00d48e6d0f7a2c2d272a5d5c8063db78bf39977adcb12e103"
]


def check_signatures_sighash_flag(decoded_tx):
if decoded_tx["tx_id"] in SIGHASH_FLAG_TRANSACTION_WHITELIST:
return

script_sig = decoded_tx["vin"][0]["script_sig"]
witnesses = []
if decoded_tx["segwit"]:
witnesses = decoded_tx["vtxinwit"][0]

flags = collect_sighash_flags(script_sig, witnesses)

if len(flags) == 0:
error = f"impossible to determine SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)

# first input must be signed with SIGHASH_ALL or SIGHASH_ALL|SIGHASH_ANYONECANPAY
authorized_flags = [b"\x01", b"\x81"]
for flag in flags:
if flag not in authorized_flags:
error = f"invalid SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)


def get_transaction_sources(decoded_tx):
sources = []
outputs_value = 0

for vin in decoded_tx["vin"][:]: # Loop through inputs.
for vin in decoded_tx["vin"]: # Loop through inputs.
vout_value, script_pubkey, _is_segwit = get_vin_info(vin)

outputs_value += vout_value
Expand Down Expand Up @@ -394,6 +537,8 @@
# Collect all (unique) source addresses.
# if we haven't found them yet
if p2sh_encoding_source is None:
if not composing:
check_signatures_sighash_flag(decoded_tx)
sources, outputs_value = get_transaction_sources(decoded_tx)
if not fee_added:
fee += outputs_value
Expand Down Expand Up @@ -524,7 +669,7 @@
return source, destination, btc_amount, fee, data, []


def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False):
def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False, composing=False):
"""Get the transaction info. Calls one of two subfunctions depending on signature type."""
if not block_index:
block_index = util.CURRENT_BLOCK_INDEX
Expand All @@ -535,12 +680,14 @@
decoded_tx,
block_index,
p2sh_is_segwit=p2sh_is_segwit,
composing=composing,
)
elif util.enabled("multisig_addresses", block_index=block_index): # Protocol change.
return get_tx_info_new(
db,
decoded_tx,
block_index,
composing=composing,
)
else:
return get_tx_info_legacy(decoded_tx, block_index)
Expand Down Expand Up @@ -604,7 +751,7 @@
]


def get_tx_info(db, decoded_tx, block_index):
def get_tx_info(db, decoded_tx, block_index, composing=False):
"""Get the transaction info. Returns normalized None data for DecodeError and BTCOnlyError."""
if util.enabled("utxo_support", block_index=block_index):
# utxos_info is a space-separated list of UTXOs, last element is the destination,
Expand All @@ -618,7 +765,7 @@
utxos_info = []
try:
source, destination, btc_amount, fee, data, dispensers_outs = _get_tx_info(
db, decoded_tx, block_index
db, decoded_tx, block_index, composing=composing
)
return source, destination, btc_amount, fee, data, dispensers_outs, utxos_info
except DecodeError as e: # noqa: F841
Expand Down
5 changes: 3 additions & 2 deletions counterparty-core/counterpartycore/lib/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@
return holders


def holders(db, asset, exclude_empty_holders=False):
def holders(db, asset, exclude_empty_holders=False, block_index=None):
Fixed Show fixed Hide fixed

Check warning

Code scanning / pylint

Unused argument 'block_index'. Warning

Unused argument 'block_index'.
"""Return holders of the asset."""
holders = []
cursor = db.cursor()
Expand All @@ -2189,8 +2189,9 @@
SELECT *, rowid
FROM balances
WHERE asset = ? AND utxo IS NOT NULL
ORDER BY rowid DESC
ORDER BY utxo
"""

bindings = (asset,)
cursor.execute(query, bindings)
holders += _get_holders(
Expand Down
12 changes: 8 additions & 4 deletions counterparty-core/counterpartycore/lib/message_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@
return "unknown"

if message_type_id == messages.utxo.ID:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
try:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
except Exception:

Check warning

Code scanning / pylint

Catching too general exception Exception. Warning

Catching too general exception Exception.
return "unknown"

return TRANSACTION_TYPE_BY_ID.get(message_type_id, "unknown")
Loading
Loading