Skip to content

Commit

Permalink
Fix logic in _get_algorithm_definitions to avoid skipping algorithm d…
Browse files Browse the repository at this point in the history
…efinitions (#498)
  • Loading branch information
alexklibisz authored Mar 19, 2024
1 parent 4c8b1c1 commit df8083a
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions ann_benchmarks/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
metric. For example, `ann_benchmarks.algorithms.nmslib` has two definitions for euclidean float
data: specifically `SW-graph(nmslib)` and `hnsw(nmslib)`, even though the module is named nmslib.
If an algorithm has an 'any' distance metric is found for the specific point type, it is used
regardless (and takes precendence) over if the distance metric is present.
If an algorithm has an 'any' distance metric, it is also included.
Returns: A mapping from the algorithm name (not the algorithm class), to the algorithm definitions, i.e.:
```
Expand Down Expand Up @@ -195,11 +194,10 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
# param `_` is filename, not specific name
for _, config in configs.items():
c = []
if "any" in config: # "any" branch must come first
c = config["any"]
elif distance_metric in config:
c = config[distance_metric]

if "any" in config:
c.extend(config["any"])
if distance_metric in config:
c.extend(config[distance_metric])
for cc in c:
definitions[cc.pop("name")] = cc

Expand Down Expand Up @@ -359,4 +357,4 @@ def get_definitions(
)


return definitions
return definitions

0 comments on commit df8083a

Please sign in to comment.