refactor(importer): reintroduce #2644 minus performance regression (#2683)

This reverts commit 165b2fe.

Fix the performance regression it introduced:

- don't create a new storage client and Datastore client for each thread
  processing a GCS blob.
andrewpollock authored Oct 21, 2024
1 parent f67ef32 commit cf73474
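As background for the client-reuse fix described above, the sketch below shows the general pattern the importer moves to: create one storage.Client() and one ndb.Client() up front and hand them to every task submitted to the thread pool, rather than constructing fresh clients in each worker thread. It also mirrors the generation-agnostic re-download used in the diff. This is a minimal illustration, not the importer's code; process_blob, the bucket name, and the worker count are made up.

import concurrent.futures

from google.cloud import ndb, storage


def process_blob(storage_client: storage.Client, ndb_client: ndb.Client,
                 blob: storage.Blob) -> int:
  """Illustrative worker: re-download the blob using the shared clients."""
  # Generation-agnostic download, in case the blob changed since listing.
  data = storage.Blob(
      blob.name, blob.bucket,
      generation=None).download_as_bytes(storage_client)
  with ndb_client.context():
    # Datastore lookups would go here, reusing the shared ndb.Client().
    pass
  return len(data)


def main():
  # Clients are created once and shared by every worker thread, instead of
  # once per thread in a pool initializer (the source of the regression).
  storage_client = storage.Client()
  ndb_client = ndb.Client()
  blobs = list(storage_client.list_blobs('example-bucket'))
  with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    futures = [
        executor.submit(process_blob, storage_client, ndb_client, blob)
        for blob in blobs
    ]
    for future in concurrent.futures.as_completed(futures):
      print(future.result())


if __name__ == '__main__':
  main()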
Showing 2 changed files with 210 additions and 84 deletions.
207 changes: 134 additions & 73 deletions docker/importer/importer.py
@@ -315,11 +315,105 @@ def _vuln_ids_from_gcs_blob(self, client: storage.Client,
strict=self._strict_validation)
except Exception as e:
logging.error('Failed to parse vulnerability %s: %s', blob.name, e)
# TODO(andrewpollock): I think this needs to be reraised here...
# a jsonschema.exceptions.ValidationError only gets raised in strict
# validation mode.
return None
for vuln in vulns:
vuln_ids.append(vuln.id)
return vuln_ids

  def _convert_blob_to_vuln(
      self, storage_client: storage.Client, ndb_client: ndb.Client,
      source_repo: osv.SourceRepository, blob: storage.Blob,
      ignore_last_import_time: bool) -> Optional[Tuple[str, str]]:
    """Parse a GCS blob into a tuple of (hash, blob name).

    Criteria for returning a tuple:
    - any record in the blob is new (i.e. a new ID) or modified since last run,
      and the hash for the blob has changed
    - the importer is reimporting the entire source
    - ignore_last_import_time is True
    - the record passes OSV JSON Schema validation

    Usually an OSV file has a single vulnerability in it, but it is permissible
    to have more than one, so every record in the blob is checked.

    This is runnable in parallel using concurrent.futures.ThreadPoolExecutor.

    Args:
      storage_client: a storage.Client() to use for retrieval of the blob
      ndb_client: an ndb.Client() to use for Datastore access
      source_repo: the osv.SourceRepository the blob relates to
      blob: the storage.Blob object to operate on
      ignore_last_import_time: whether to consider the blob regardless of the
        source repository's last update time

    Raises:
      jsonschema.exceptions.ValidationError: when self._strict_validation is
        True and the input fails OSV JSON Schema validation

    Returns:
      a tuple of (blob hash, blob name), or None when the blob has an
      unexpected name, is unchanged since the last import, or contains no new
      or modified records
    """
if not _is_vulnerability_file(source_repo, blob.name):
return None

utc_last_update_date = source_repo.last_update_date.replace(
tzinfo=datetime.timezone.utc)

if (not ignore_last_import_time and blob.updated and
blob.updated <= utc_last_update_date):
return None

# The record in GCS appears to be new/changed, examine further.
logging.info('Bucket entry triggered for %s/%s', source_repo.bucket,
blob.name)

    # Download in a blob-generation-agnostic way to cope with the blob changing
    # between when it was listed and now (otherwise retrieval fails if the
    # generation no longer matches).
blob_bytes = storage.Blob(
blob.name, blob.bucket,
generation=None).download_as_bytes(storage_client)

blob_hash = osv.sha256_bytes(blob_bytes)

# When self._strict_validation is True,
# this *may* raise a jsonschema.exceptions.ValidationError
vulns = osv.parse_vulnerabilities_from_data(
blob_bytes,
os.path.splitext(blob.name)[1],
strict=self._strict_validation)

# TODO(andrewpollock): integrate with linter here.

# This is the atypical execution path (when reimporting is triggered)
if ignore_last_import_time:
return blob_hash, blob.name

    # If being run under test, reuse the existing NDB context.
ndb_ctx = ndb.context.get_context(False)
if ndb_ctx is None:
# Production. Use the NDB client passed in.
ndb_ctx = ndb_client.context()
else:
      # Unit testing. Reuse the unit test's existing NDB context to avoid
      # "RuntimeError: Context is already created for this thread."
ndb_ctx = ndb_ctx.use()

# This is the typical execution path (when reimporting not triggered)
with ndb_ctx:
for vuln in vulns:
bug = osv.Bug.get_by_id(vuln.id)
        # The record is new, or has been modified since the last import
if bug is None or \
bug.import_last_modified != vuln.modified.ToDatetime():
return blob_hash, blob.name

return None

return None

def _sync_from_previous_commit(self, source_repo, repo):
"""Sync the repository from the previous commit.
@@ -444,13 +538,16 @@ def _process_updates_bucket(self, source_repo: osv.SourceRepository):
source_repo.ignore_last_import_time = False
source_repo.put()

-    # First retrieve a list of files to parallel download
    storage_client = storage.Client()
-    utc_last_update_date = source_repo.last_update_date.replace(
-        tzinfo=datetime.timezone.utc)

+    # Get all of the existing records in the GCS bucket
+    logging.info(
+        'Listing blobs in gs://%s',
+        os.path.join(source_repo.bucket,
+                     ('' if source_repo.directory_path is None else
+                      source_repo.directory_path)))
    # Convert to list to retrieve all information into memory
-    # This makes its use in the concurrent map later faster
+    # This makes its concurrent use later faster
    listed_blobs = list(
        storage_client.list_blobs(
            source_repo.bucket,
@@ -459,86 +556,50 @@ def _process_updates_bucket(self, source_repo: osv.SourceRepository):

import_failure_logs = []

-    # TODO(andrewpollock): externalise like _vuln_ids_from_gcs_blob()
-    def convert_blob_to_vuln(blob: storage.Blob) -> Optional[Tuple[str, str]]:
-      """Download and parse GCS blob into [blob_hash, blob.name]"""
-      if not _is_vulnerability_file(source_repo, blob.name):
-        return None
-      if not ignore_last_import_time and \
-          blob.updated is not None and \
-          not blob.updated > utc_last_update_date:
-        return None
-
-      logging.info('Bucket entry triggered for %s/%s', source_repo.bucket,
-                   blob.name)
-      # Use the _client_store thread local variable
-      # set in the thread pool initializer
-      # Download in a blob generation agnostic way to cope with the file
-      # changing between when it was listed and now.
-      blob_bytes = storage.Blob(
-          blob.name, blob.bucket,
-          generation=None).download_as_bytes(_client_store.storage_client)
-      if ignore_last_import_time:
-        blob_hash = osv.sha256_bytes(blob_bytes)
-        if self._strict_validation:
-          try:
-            _ = osv.parse_vulnerabilities_from_data(
-                blob_bytes,
-                os.path.splitext(blob.name)[1],
-                strict=self._strict_validation)
-          except Exception as e:
-            logging.error('Failed to parse vulnerability %s: %s', blob.name, e)
-            import_failure_logs.append('Failed to parse vulnerability "' +
-                                       blob.name + '"')
-            return None
-        return blob_hash, blob.name
-
-      with _client_store.ndb_client.context():
+    # Get the hash and the parsed vulnerability from every GCS object that
+    # parses as an OSV record. Do this in parallel for a degree of expedience.
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=_BUCKET_THREAD_COUNT) as executor:
+
+      logging.info('Parallel-parsing %d blobs in %s', len(listed_blobs),
+                   source_repo.name)
+      datastore_client = ndb.Client()
+      future_to_blob = {
+          executor.submit(self._convert_blob_to_vuln, storage_client,
+                          datastore_client, source_repo, blob,
+                          ignore_last_import_time):
+              blob for blob in listed_blobs
+      }
+
+      converted_vulns = []
+      logging.info('Processing %d parallel-parsed blobs in %s',
+                   len(future_to_blob), source_repo.name)
+
+      for future in concurrent.futures.as_completed(future_to_blob):
+        blob = future_to_blob[future]
        try:
-          vulns = osv.parse_vulnerabilities_from_data(
-              blob_bytes,
-              os.path.splitext(blob.name)[1],
-              strict=self._strict_validation)
-          for vuln in vulns:
-            bug = osv.Bug.get_by_id(vuln.id)
-            # Check if the bug has been modified since last import
-            if bug is None or \
-                bug.import_last_modified != vuln.modified.ToDatetime():
-              blob_hash = osv.sha256_bytes(blob_bytes)
-              return blob_hash, blob.name
-
-          return None
+          if future.result():
+            converted_vulns.append(([vuln for vuln in future.result() if vuln]))
        except Exception as e:
-          logging.error('Failed to parse vulnerability %s: %s', blob.name, e)
-          # Don't include error stack trace as that might leak sensitive info
-          # List.append() is atomic and threadsafe.
-          import_failure_logs.append('Failed to parse vulnerability "' +
-                                     blob.name + '"')
-          return None
-
-    # Setup storage client
-    def thread_init():
-      _client_store.storage_client = storage.Client()
-      _client_store.ndb_client = ndb.Client()
+          logging.error('Failed to parse vulnerability %s: %s', blob.name, e)
+          import_failure_logs.append(
+              'Failed to parse vulnerability (when considering for import) "' +
+              blob.name + '"')

-    # TODO(andrewpollock): switch to using c.f.submit() like in
-    # _process_deletions_bucket()
-    with concurrent.futures.ThreadPoolExecutor(
-        _BUCKET_THREAD_COUNT, initializer=thread_init) as executor:
-      converted_vulns = executor.map(convert_blob_to_vuln, listed_blobs)
      for cv in converted_vulns:
        if cv:
          logging.info('Requesting analysis of bucket entry: %s/%s',
                       source_repo.bucket, cv[1])
          self._request_analysis_external(source_repo, cv[0], cv[1])

-      replace_importer_log(storage_client, source_repo.name,
-                           self._public_log_bucket, import_failure_logs)
+    replace_importer_log(storage_client, source_repo.name,
+                         self._public_log_bucket, import_failure_logs)

-      source_repo.last_update_date = import_time_now
-      source_repo.put()
+    source_repo.last_update_date = import_time_now
+    source_repo.put()

-      logging.info('Finished processing bucket: %s', source_repo.name)
+    logging.info('Finished processing bucket: %s', source_repo.name)

def _process_deletions_bucket(self,
source_repo: osv.SourceRepository,
@@ -598,7 +659,7 @@ def _process_deletions_bucket(self,
logging.info('Parallel-parsing %d blobs in %s', len(listed_blobs),
source_repo.name)
future_to_blob = {
-          executor.submit(self._vuln_ids_from_gcs_blob, storage.Client(),
+          executor.submit(self._vuln_ids_from_gcs_blob, storage_client,
source_repo, blob):
blob for blob in listed_blobs
}
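A closing note on the two thread-pool styles visible in the diff: the removed code used executor.map with a per-thread initializer, while the reintroduced code uses executor.submit plus concurrent.futures.as_completed, which keeps the clients shared and lets each blob's failure be caught and logged without aborting the batch. The standalone sketch below demonstrates that submit/as_completed error-handling shape; parse_record and the file names are illustrative, not OSV code.

import concurrent.futures
import logging


def parse_record(name: str) -> str:
  """Illustrative worker; fails for one input to show per-future errors."""
  if name == 'bad-record.json':
    raise ValueError('unparseable record')
  return name.upper()


def main():
  names = ['a.json', 'bad-record.json', 'b.json']
  failures = []
  with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_name = {executor.submit(parse_record, n): n for n in names}
    for future in concurrent.futures.as_completed(future_to_name):
      name = future_to_name[future]
      try:
        print(future.result())
      except Exception as e:
        # One bad input is logged and skipped; with executor.map the exception
        # would surface while iterating results and end the iteration early.
        logging.error('Failed to parse %s: %s', name, e)
        failures.append(name)
  print('failures:', failures)


if __name__ == '__main__':
  main()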