-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add shim to adapt the lightfm library (#219)
Similar to #218, this change adds an adapter that lets lightfm work with the high-level merlin-models API.
- Loading branch information
Showing
8 changed files
with
213 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import multiprocessing | ||
|
||
import lightfm | ||
import lightfm.evaluation | ||
|
||
from merlin.io import Dataset | ||
from merlin.models.utils.dataset import dataset_to_coo | ||
|
||
|
||
class LightFM:
    """Adapts a model from lightfm to work with the high level merlin-models api.

    Example usage::

        # Get the movielens dataset
        from merlin.models.utils.data_etl_utils import get_movielens

        train, valid = get_movielens()

        # Train a WARP model with lightfm using the merlin movielens dataset
        from merlin.models.lightfm import LightFM

        model = LightFM(learning_rate=0.05, loss="warp")
        model.fit(train)

        # evaluate the model given the validation set
        print(model.evaluate(valid))

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the ``lightfm.LightFM`` constructor.
    epochs : int
        Number of epochs passed to ``lightfm.LightFM.fit``.
    num_threads : int
        Number of threads handed to lightfm. A falsy value (the default ``0``)
        means "use ``multiprocessing.cpu_count()``".
    """

    def __init__(self, *args, epochs=10, num_threads=0, **kwargs):
        self.lightfm_model = lightfm.LightFM(*args, **kwargs)
        self.epochs = epochs
        self.num_threads = num_threads or multiprocessing.cpu_count()

    def fit(self, train: Dataset):
        """Trains the lightfm model.

        The fitted interaction matrix is kept on ``self.train_data`` because
        ``evaluate`` needs it afterwards.

        Parameters
        ----------
        train : merlin.io.Dataset
            The training dataset to use to fit the model. We will use the column tagged
            merlin.schema.Tags.ITEM_ID as the item, and merlin.schema.Tags.USER_ID as the userid.
            If there is a column tagged as Tags.TARGET we will also use that for the values,
            otherwise the values will be set to 1.
        """
        data = dataset_to_coo(train).tocsr()
        self.lightfm_model.fit(data, epochs=self.epochs, num_threads=self.num_threads)
        self.train_data = data

    def evaluate(self, test_dataset: Dataset, k=10):
        """Evaluates the model.

        This function evaluates using a variety of ranking metrics, and returns
        a dictionary of {metric_name: value}. ``fit`` must have been called first,
        since the training interactions stored by ``fit`` are read here.

        Parameters
        ----------
        test_dataset : merlin.io.Dataset
            The validation dataset to evaluate
        k : int
            How many items to return per prediction. Only precision is cut off
            at ``k``; AUC is computed by lightfm without a cutoff.

        Returns
        -------
        dict
            ``{"precisions@<k>": mean precision at k, "auc@<k>": mean AUC}``. The
            ``auc@<k>`` key name is kept as-is for compatibility even though the
            AUC value itself does not depend on ``k``.
        """
        test = dataset_to_coo(test_dataset).tocsr()

        # lightfm needs the test set to have the same dimensionality as the train set.
        # NOTE: scipy's resize drops any test entries whose user/item index falls
        # outside the training matrix shape.
        test.resize(self.train_data.shape)

        precision = lightfm.evaluation.precision_at_k(
            self.lightfm_model, test, self.train_data, k=k, num_threads=self.num_threads
        ).mean()
        # auc_score has no ``k`` parameter; passing one raises a TypeError.
        auc = lightfm.evaluation.auc_score(
            self.lightfm_model, test, self.train_data, num_threads=self.num_threads
        ).mean()
        return {f"precisions@{k}": precision, f"auc@{k}": auc}

    def predict(self, dataset: Dataset, k=10):
        """Generate predictions from the dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            Dataset whose (user, item) pairs should be scored.
        k : int
            Currently unused; accepted only so existing callers keep working.

        Returns
        -------
        The result of ``lightfm.LightFM.predict`` for the user ids (COO rows) and
        item ids (COO columns) extracted from ``dataset``.
        """
        data = dataset_to_coo(dataset)
        # Use the configured thread count here too, consistent with fit/evaluate.
        return self.lightfm_model.predict(data.row, data.col, num_threads=self.num_threads)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from scipy.sparse import coo_matrix | ||
|
||
from merlin.io import Dataset | ||
from merlin.schema import Tags | ||
|
||
|
||
def dataset_to_coo(dataset: Dataset):
    """Build a scipy COO matrix of interactions from a merlin.io.Dataset.

    Rows come from the column tagged ``Tags.USER_ID`` and columns from the one
    tagged ``Tags.ITEM_ID``. When exactly one column is tagged ``Tags.TARGET``
    its values fill the matrix; with no target column every entry is 1.
    Values are stored as float32.

    Raises
    ------
    ValueError
        If more than one column is tagged ``Tags.TARGET``.
    """
    schema = dataset.schema
    user_id_column = schema.select_by_tag(Tags.USER_ID).first.name
    item_id_column = schema.select_by_tag(Tags.ITEM_ID).first.name

    target = schema.select_by_tag(Tags.TARGET)
    n_targets = len(target)
    if n_targets > 1:
        raise ValueError(
            "Found more than one column tagged Tags.TARGET in the dataset schema."
            f" Expected a single target column but found {target.column_names}"
        )

    # Only pull the columns we actually need out of the dataset.
    columns = [user_id_column, item_id_column]
    target_column = None
    if n_targets == 1:
        target_column = target.first.name
        columns.append(target_column)

    frame = dataset.to_ddf()[columns].compute(scheduler="synchronous")

    rows = _to_numpy(frame[user_id_column])
    cols = _to_numpy(frame[item_id_column])
    if target_column:
        values = _to_numpy(frame[target_column])
    else:
        # No explicit target: every observed interaction gets a value of 1.
        values = np.ones(len(rows))
    return coo_matrix((values.astype("float32"), (rows, cols)))
|
||
|
||
def _to_numpy(series): | ||
"""converts a pandas or cudf series to a numpy array""" | ||
if isinstance(series, pd.Series): | ||
return series.values | ||
else: | ||
return series.values_host |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
lightfm>=1.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# | ||
# Copyright (c) 2021, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import pytest | ||
|
||
# lightfm is an optional dependency: skip these tests instead of failing
# with an ImportError when it is not installed.
pytest.importorskip("lightfm")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# | ||
# Copyright (c) 2021, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from merlin.io import Dataset | ||
from merlin.models.data.synthetic import SyntheticData | ||
from merlin.models.lightfm import LightFM | ||
from merlin.schema import Tags | ||
|
||
|
||
def test_warp(music_streaming_data: SyntheticData):
    """Smoke test: a WARP lightfm model can be fit on, and predict for, a merlin Dataset."""
    # Strip every Tags.TARGET column from the schema before building the Dataset.
    schema_without_targets = music_streaming_data.schema.remove_by_tag(Tags.TARGET)
    music_streaming_data._schema = schema_without_targets
    dataset = Dataset(music_streaming_data.dataframe, schema=music_streaming_data.schema)

    hyperparams = {"learning_rate": 0.05, "loss": "warp", "epochs": 10}
    model = LightFM(**hyperparams)
    model.fit(dataset)

    model.predict(dataset)