Skip to content

Commit

Permalink
Add sparse matrix utilities tfidf
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Sourty committed Sep 5, 2023
1 parent e54b86a commit e8d7c83
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cherche/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (2, 0, 3)
VERSION = (2, 0, 4)

__version__ = ".".join(map(str, VERSION))
36 changes: 33 additions & 3 deletions cherche/retrieve/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import tqdm
from scipy.sparse import csc_matrix
from scipy.sparse import csc_matrix, csr_matrix
from sklearn.feature_extraction.text import TfidfVectorizer

from ..utils import yield_batch
Expand Down Expand Up @@ -42,6 +42,7 @@ class TfIdf(Retriever):
... ]
>>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents)
>>> retriever = retriever.to_csr()
>>> retriever
TfIdf retriever
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
tfidf: TfidfVectorizer = None,
k: typing.Optional[int] = None,
batch_size: int = 1024,
fit: bool = True,
) -> None:
super().__init__(key=key, on=on, k=k, batch_size=batch_size)

Expand All @@ -91,8 +93,10 @@ def __init__(

self.documents = [{self.key: document[self.key]} for document in documents]

method = self.tfidf.fit_transform if fit else self.tfidf.transform

self.matrix = csc_matrix(
self.tfidf.fit_transform(
method(
[
" ".join([doc.get(field, "") for field in self.on])
for doc in documents
Expand All @@ -102,6 +106,27 @@ def __init__(

self.k = len(self.documents) if k is None else k
self.n = len(self.documents)
self.is_csc = True

def to_csr(self) -> "TfIdf":
    """Transpose the document matrix and store it in CSR format.

    Intended as a speed-up when retrieving documents for multiple queries
    at once. Calling it while the matrix is already in CSR format is a
    no-op.

    Returns
    -------
    The retriever itself, so the call can be chained.
    """
    if self.is_csc:
        # The conversion also transposes the stored matrix, so any code
        # reading self.matrix must check is_csc to know its orientation.
        self.matrix = csr_matrix(self.matrix.T)
        self.is_csc = False
    return self

def to_csc(self) -> "TfIdf":
    """Transpose the document matrix and store it in CSC format.

    Intended as a speed-up when retrieving documents for a single query.
    Calling it while the matrix is already in CSC format is a no-op.

    Returns
    -------
    The retriever itself, so the call can be chained.
    """
    if not self.is_csc:
        # The conversion also transposes the stored matrix, so any code
        # reading self.matrix must check is_csc to know its orientation.
        self.matrix = csc_matrix(self.matrix.T)
        self.is_csc = True
    return self

def top_k_by_partition(
self, similarities: np.ndarray, k: int
Expand Down Expand Up @@ -163,7 +188,12 @@ def __call__(
desc=f"{self.__class__.__name__} retriever",
tqdm_bar=tqdm_bar,
):
similarities = self.tfidf.transform(batch).dot(self.matrix).toarray()
if self.is_csc:
similarities = self.tfidf.transform(batch).dot(self.matrix).toarray()
else:
similarities = (
self.matrix.dot(self.tfidf.transform(batch).T).toarray().T
)

batch_match, batch_similarities = self.top_k_by_partition(
similarities=similarities, k=k
Expand Down

0 comments on commit e8d7c83

Please sign in to comment.