Skip to content

Commit

Permalink
fix faiss setup to take dataframe as input (#378)
Browse files Browse the repository at this point in the history
* fix faiss setup to take dataframe as input

* use make df in faiss test
  • Loading branch information
jperez999 authored Jul 1, 2023
1 parent 56b3adc commit 57b8211
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
16 changes: 12 additions & 4 deletions merlin/systems/dag/ops/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

from merlin.core.dispatch import HAS_GPU
from merlin.core.protocols import Transformable
from merlin.core.protocols import DataFrameLike, Transformable
from merlin.dag import ColumnSelector
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ops.operator import InferenceOperator
Expand Down Expand Up @@ -189,7 +189,13 @@ def validate_schemas(
)


def setup_faiss(item_vector, output_path: str, metric=faiss.METRIC_INNER_PRODUCT):
def setup_faiss(
item_vector: DataFrameLike,
output_path: str,
metric=faiss.METRIC_INNER_PRODUCT,
item_id_column="item_id",
embedding_column="embedding",
):
"""
Utility function that will create a Faiss index from a set of embedding vectors
Expand All @@ -200,8 +206,10 @@ def setup_faiss(item_vector, output_path: str, metric=faiss.METRIC_INNER_PRODUCT
output_path : string
target output path
"""
ids = item_vector[:, 0].astype(np.int64)
item_vectors = np.ascontiguousarray(item_vector[:, 1:].astype(np.float32))
ids = item_vector[item_id_column].to_numpy().astype(np.int64)
item_vectors = np.ascontiguousarray(
np.stack(item_vector[embedding_column].to_numpy()).astype(np.float32)
)

index = faiss.index_factory(item_vectors.shape[1], "IVF32,Flat", metric)
index.nprobe = 8
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/systems/ops/faiss/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import pytest

from merlin.core.dispatch import make_df
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.faiss import QueryFaiss, setup_faiss
Expand Down Expand Up @@ -57,9 +58,11 @@ def test_faiss_in_triton_executor_model(tmpdir):
)

faiss_path = tmpdir / "faiss.index"
item_ids = np.arange(0, 100).reshape(-1, 1)
item_embeddings = np.ascontiguousarray(np.random.rand(100, 128))
setup_faiss(np.concatenate((item_ids, item_embeddings), axis=1), faiss_path)
item_ids = np.arange(0, 100)
item_embeddings = np.random.rand(100, 128)
# a cudf list column cannot be converted directly to numpy, so build the frame on CPU (pandas) as a bridge
df = make_df({"item_id": item_ids, "embedding": item_embeddings.tolist()}, device="cpu")
setup_faiss(df, faiss_path)

request_schema = Schema(
[
Expand Down

0 comments on commit 57b8211

Please sign in to comment.