Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Sighash Flag #2807

Merged
merged 25 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruff-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ jobs:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
version: 0.7.4
version: 0.8.2
2 changes: 1 addition & 1 deletion .github/workflows/ruff-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ jobs:
- uses: chartboost/ruff-action@v1
with:
args: "format --check"
version: 0.7.4
version: 0.8.2
4,634 changes: 2,324 additions & 2,310 deletions apiary.apib

Large diffs are not rendered by default.

45 changes: 23 additions & 22 deletions counterparty-core/counterpartycore/lib/api/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@
database.update_version(state_db)


def run_api_server(args, server_ready_value, stop_event):
def run_api_server(args, server_ready_value, stop_event, parent_pid):
logger.info("Starting API Server process...")

def handle_interrupt_signal(signum, frame):
Expand Down Expand Up @@ -513,7 +513,7 @@
wsgi_server = wsgi.WSGIApplication(app, args=args)

logger.info("Starting Parent Process Checker thread...")
parent_checker = ParentProcessChecker(wsgi_server)
parent_checker = ParentProcessChecker(wsgi_server, stop_event, parent_pid)
parent_checker.start()

app.app_context().push()
Expand All @@ -539,52 +539,53 @@
watcher.stop()
watcher.join()

if parent_checker is not None:
logger.trace("Stopping Parent Process Checker thread...")
parent_checker.stop()
parent_checker.join()

logger.info("API Server stopped.")


def is_process_alive(pid):
"""Check For the existence of a unix pid."""
try:

Check warning

Code scanning / pylint

Unnecessary "else" after "return", remove the "else" and de-indent the code inside it. Warning

Unnecessary "else" after "return", remove the "else" and de-indent the code inside it.
os.kill(pid, 0)
except OSError:
return False
else:
return True


# This thread is used for the following two reasons:
# 1. `docker-compose stop` does not send a SIGTERM to the child processes (in this case the API v2 process)
# 2. `process.terminate()` does not trigger a `KeyboardInterrupt` or execute the `finally` block.
class ParentProcessChecker(threading.Thread):
def __init__(self, wsgi_server):
def __init__(self, wsgi_server, stop_event, parent_pid):
super().__init__(name="ParentProcessChecker")
self.daemon = True
self.wsgi_server = wsgi_server
self.stop_event = threading.Event()
self.stop_event = stop_event
self.parent_pid = parent_pid

def run(self):
parent_pid = os.getppid()
try:
while not self.stop_event.is_set():
if os.getppid() != parent_pid:
logger.debug("Parent process is dead. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
break
while not self.stop_event.is_set() and is_process_alive(self.parent_pid):
time.sleep(1)
logger.debug("Parent process stopped. Exiting...")
if self.wsgi_server is not None:
self.wsgi_server.stop()
except KeyboardInterrupt:
pass

def stop(self):
self.stop_event.set()


class APIServer(object):
def __init__(self):
def __init__(self, stop_event):
self.process = None
self.server_ready_value = Value("I", 0)
self.stop_event = multiprocessing.Event()
self.stop_event = stop_event

def start(self, args):
if self.process is not None:
raise Exception("API Server is already running")
self.process = Process(
target=run_api_server, args=(vars(args), self.server_ready_value, self.stop_event)
target=run_api_server,
args=(vars(args), self.server_ready_value, self.stop_event, os.getpid()),
)
self.process.start()
return self.process
Expand Down
1 change: 1 addition & 0 deletions counterparty-core/counterpartycore/lib/api/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def get_tx_info(tx_hex, block_index=None):
db,
deserialize.deserialize_tx(tx_hex, use_txid=use_txid),
block_index=block_index,
composing=True,
)
)
return (
Expand Down
14 changes: 10 additions & 4 deletions counterparty-core/counterpartycore/lib/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,21 @@ def read_transaction(vds, use_txid=True):
offset_before_tx_witnesses = vds.read_cursor - start_pos
for vin in transaction["vin"]: # noqa: B007
witnesses_count = vds.read_compact_size()
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
transaction["vtxinwit"].append(witness)
if witnesses_count == 0:
transaction["vtxinwit"].append([])
else:
vin_witnesses = []
for i in range(witnesses_count): # noqa: B007
witness_length = vds.read_compact_size()
witness = vds.read_bytes(witness_length)
vin_witnesses.append(witness)
transaction["vtxinwit"].append(vin_witnesses)

transaction["lock_time"] = vds.read_uint32()
data = vds.input[start_pos : vds.read_cursor]

transaction["tx_hash"] = ib2h(double_hash(data))
transaction["tx_id"] = transaction["tx_hash"]
if transaction["segwit"]:
hash_data = data[:4] + data[6:offset_before_tx_witnesses] + data[-4:]
transaction["tx_id"] = ib2h(double_hash(hash_data))
Expand Down
122 changes: 117 additions & 5 deletions counterparty-core/counterpartycore/lib/gettxinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def get_vin_info(vin):
if "value" in vin:
return vin["value"], vin["script_pub_key"], vin["is_segwit"]

# Note: We don't know what block the `vin` is in, and the block might have been from a while ago, so this call may not hit the cache.
# Note: We don't know what block the `vin` is in, and the block might
# have been from a while ago, so this call may not hit the cache.
vin_ctx = backend.bitcoind.get_decoded_transaction(vin["hash"])

is_segwit = len(vin_ctx["vtxinwit"]) > 0
Expand All @@ -155,11 +156,118 @@ def get_vin_info(vin):
return vout["value"], vout["script_pub_key"], is_segwit


def get_der_signature_sighash_flag(value):
if not isinstance(value, bytes):
return None
lenght_by_prefix = {
"3044": 71,
"3045": 72,
"3046": 73,
"3041": 68,
"3042": 69,
"3043": 70,
}
for prefix, length in lenght_by_prefix.items():
if value.startswith(binascii.unhexlify(prefix)) and len(value) == length:
return value[-1:]
return None


def get_schnorr_signature_sighash_flag(value):

Check warning

Code scanning / pylint

Either all return statements in a function should return an expression, or none of them should. Warning

Either all return statements in a function should return an expression, or none of them should.
if not isinstance(value, bytes):
return None
if len(value) not in [64, 65]:
return None
if len(value) == 65:
return value[-1:]
return b"\x01" # SIGHASH_ALL by default


def collect_sighash_flags(script_sig, witnesses):
flags = []

# P2PK, P2PKH, P2MS
if script_sig != b"":
asm = script.script_to_asm(script_sig)
for item in asm:
flag = get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)

if len(witnesses) == 0:
return flags

witnesses = [
binascii.unhexlify(witness) if isinstance(witness, str) else witness
for witness in witnesses
]

# P2WPKH
if len(witnesses) == 2:
flag = get_der_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# P2TR key path spend
if len(witnesses) == 1:
flag = get_schnorr_signature_sighash_flag(witnesses[0])
if flag is not None:
flags.append(flag)
return flags

# Other cases
if len(witnesses) >= 3:
for item in witnesses:
flag = get_schnorr_signature_sighash_flag(item) or get_der_signature_sighash_flag(item)
if flag is not None:
flags.append(flag)
return flags

return flags


# class SighashFlagError(DecodeError):
class SighashFlagError(Exception):
pass


# known transactions with invalid SIGHASH flag
SIGHASH_FLAG_TRANSACTION_WHITELIST = [
"c8091f1ef768a2f00d48e6d0f7a2c2d272a5d5c8063db78bf39977adcb12e103"
]


def check_signatures_sighash_flag(decoded_tx):
if decoded_tx["tx_id"] in SIGHASH_FLAG_TRANSACTION_WHITELIST:
return

script_sig = decoded_tx["vin"][0]["script_sig"]
witnesses = []
if decoded_tx["segwit"]:
witnesses = decoded_tx["vtxinwit"][0]

flags = collect_sighash_flags(script_sig, witnesses)

if len(flags) == 0:
error = f"impossible to determine SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)

# first input must be signed with SIGHASH_ALL or SIGHASH_ALL|SIGHASH_ANYONECANPAY
authorized_flags = [b"\x01", b"\x81"]
for flag in flags:
if flag not in authorized_flags:
error = f"invalid SIGHASH flag for transaction {decoded_tx['tx_id']}"
logger.debug(error)
raise SighashFlagError(error)


def get_transaction_sources(decoded_tx):
sources = []
outputs_value = 0

for vin in decoded_tx["vin"][:]: # Loop through inputs.
for vin in decoded_tx["vin"]: # Loop through inputs.
vout_value, script_pubkey, _is_segwit = get_vin_info(vin)

outputs_value += vout_value
Expand Down Expand Up @@ -394,6 +502,8 @@ def get_tx_info_new(db, decoded_tx, block_index, p2sh_is_segwit=False, composing
# Collect all (unique) source addresses.
# if we haven't found them yet
if p2sh_encoding_source is None:
if not composing:
check_signatures_sighash_flag(decoded_tx)
sources, outputs_value = get_transaction_sources(decoded_tx)
if not fee_added:
fee += outputs_value
Expand Down Expand Up @@ -524,7 +634,7 @@ def get_tx_info_legacy(decoded_tx, block_index):
return source, destination, btc_amount, fee, data, []


def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False):
def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False, composing=False):
"""Get the transaction info. Calls one of two subfunctions depending on signature type."""
if not block_index:
block_index = util.CURRENT_BLOCK_INDEX
Expand All @@ -535,12 +645,14 @@ def _get_tx_info(db, decoded_tx, block_index, p2sh_is_segwit=False):
decoded_tx,
block_index,
p2sh_is_segwit=p2sh_is_segwit,
composing=composing,
)
elif util.enabled("multisig_addresses", block_index=block_index): # Protocol change.
return get_tx_info_new(
db,
decoded_tx,
block_index,
composing=composing,
)
else:
return get_tx_info_legacy(decoded_tx, block_index)
Expand Down Expand Up @@ -604,7 +716,7 @@ def get_utxos_info(db, decoded_tx):
]


def get_tx_info(db, decoded_tx, block_index):
def get_tx_info(db, decoded_tx, block_index, composing=False):
"""Get the transaction info. Returns normalized None data for DecodeError and BTCOnlyError."""
if util.enabled("utxo_support", block_index=block_index):
# utxos_info is a space-separated list of UTXOs, last element is the destination,
Expand All @@ -618,7 +730,7 @@ def get_tx_info(db, decoded_tx, block_index):
utxos_info = []
try:
source, destination, btc_amount, fee, data, dispensers_outs = _get_tx_info(
db, decoded_tx, block_index
db, decoded_tx, block_index, composing=composing
)
return source, destination, btc_amount, fee, data, dispensers_outs, utxos_info
except DecodeError as e: # noqa: F841
Expand Down
12 changes: 8 additions & 4 deletions counterparty-core/counterpartycore/lib/message_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@
return "unknown"

if message_type_id == messages.utxo.ID:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
try:
message_data = messages.utxo.unpack(message, return_dict=True)
if util.is_utxo_format(message_data["source"]):
return "detach"
return "attach"
except Exception:

Check warning

Code scanning / pylint

Catching too general exception Exception. Warning

Catching too general exception Exception.
return "unknown"

return TRANSACTION_TYPE_BY_ID.get(message_type_id, "unknown")
6 changes: 2 additions & 4 deletions counterparty-core/counterpartycore/lib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ def extract_bitcoincore_config():
"rpcssl": "backend-ssl",
}

for bitcoind_key in config_keys:
for bitcoind_key, counterparty_key in config_keys.items():
if bitcoind_key in conf:
counterparty_key = config_keys[bitcoind_key]
bitcoincore_config[counterparty_key] = conf[bitcoind_key]

return bitcoincore_config
Expand Down Expand Up @@ -144,9 +143,8 @@ def server_to_client_config(server_config):
"rpc-password": "counterparty-rpc-password",
}

for server_key in config_keys:
for server_key, client_key in config_keys.items():
if server_key in server_config:
client_key = config_keys[server_key]
client_config[client_key] = server_config[server_key]

return client_config
Expand Down
7 changes: 6 additions & 1 deletion counterparty-core/counterpartycore/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import binascii
import decimal
import logging
import multiprocessing
import os
import sys
import tarfile
Expand Down Expand Up @@ -739,6 +740,7 @@ def start_all(args):
follower_daemon = None
asset_conservation_checker = None
db = None
api_stop_event = None

# Log all config parameters, sorted by key
# Filter out default values #TODO: these should be set in a different way
Expand Down Expand Up @@ -766,7 +768,8 @@ def start_all(args):
check.software_version()

# API Server v2
api_server_v2 = api_v2.APIServer()
api_stop_event = multiprocessing.Event()
api_server_v2 = api_v2.APIServer(api_stop_event)
api_server_v2.start(args)
while not api_server_v2.is_ready() and not api_server_v2.has_stopped():
logger.trace("Waiting for API server to start...")
Expand Down Expand Up @@ -812,6 +815,8 @@ def start_all(args):
logger.error("Exception caught!", exc_info=e)
finally:
# Ensure all threads are stopped
if api_stop_event:
api_stop_event.set()
if api_status_poller:
api_status_poller.stop()
if api_server_v1:
Expand Down
Loading
Loading