diff --git a/cherche/__version__.py b/cherche/__version__.py index 25e7d8c..0579ac3 100644 --- a/cherche/__version__.py +++ b/cherche/__version__.py @@ -1,3 +1,3 @@ -VERSION = (2, 0, 3) +VERSION = (2, 0, 4) __version__ = ".".join(map(str, VERSION)) diff --git a/cherche/retrieve/tfidf.py b/cherche/retrieve/tfidf.py index 7024f3d..d99a6c4 100644 --- a/cherche/retrieve/tfidf.py +++ b/cherche/retrieve/tfidf.py @@ -4,7 +4,7 @@ import numpy as np import tqdm -from scipy.sparse import csc_matrix +from scipy.sparse import csc_matrix, csr_matrix from sklearn.feature_extraction.text import TfidfVectorizer from ..utils import yield_batch @@ -42,6 +42,7 @@ class TfIdf(Retriever): ... ] >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) + >>> retriever = retriever.to_csr() >>> retriever TfIdf retriever @@ -80,6 +81,7 @@ def __init__( tfidf: TfidfVectorizer = None, k: typing.Optional[int] = None, batch_size: int = 1024, + fit: bool = True, ) -> None: super().__init__(key=key, on=on, k=k, batch_size=batch_size) @@ -91,8 +93,10 @@ def __init__( self.documents = [{self.key: document[self.key]} for document in documents] + method = self.tfidf.fit_transform if fit else self.tfidf.transform + self.matrix = csc_matrix( - self.tfidf.fit_transform( + method( [ " ".join([doc.get(field, "") for field in self.on]) for doc in documents @@ -102,6 +106,27 @@ def __init__( self.k = len(self.documents) if k is None else k self.n = len(self.documents) + self.is_csc = True + + def to_csr(self) -> "TfIdf": + """Convert the matrix to a csr matrix. + + Speed-up if you want to retrieve documents from multiples queries. + """ + if self.is_csc: + self.matrix = csr_matrix(self.matrix.T) + self.is_csc = False + return self + + def to_csc(self) -> "TfIdf": + """Convert the matrix to a csc matrix. + + Speed-up if you want to retrieve documents from a sinle query. + """ + if not self.is_csc: + self.matrix = csc_matrix(self.matrix.T) + self.is_csc = True + return self def top_k_by_partition( self, similarities: np.ndarray, k: int @@ -163,7 +188,12 @@ def __call__( desc=f"{self.__class__.__name__} retriever", tqdm_bar=tqdm_bar, ): - similarities = self.tfidf.transform(batch).dot(self.matrix).toarray() + if self.is_csc: + similarities = self.tfidf.transform(batch).dot(self.matrix).toarray() + else: + similarities = ( + self.matrix.dot(self.tfidf.transform(batch).T).toarray().T + ) batch_match, batch_similarities = self.top_k_by_partition( similarities=similarities, k=k