From 265cf49581eb3865a8d9f9501a5d1fa8e810b44c Mon Sep 17 00:00:00 2001
From: Roman Romanov
Date: Mon, 19 Aug 2024 10:23:51 +0300
Subject: [PATCH 1/4] Add example of search using multi-modal embeddings

---
 examples/image_search/app.py          | 125 ++++++++++++++++++++++++++
 examples/image_search/attachment.py   |  39 ++++++++
 examples/image_search/embeddings.py   |  70 +++++++++++++++
 examples/image_search/vector_store.py |  20 +++++
 4 files changed, 254 insertions(+)
 create mode 100644 examples/image_search/app.py
 create mode 100644 examples/image_search/attachment.py
 create mode 100644 examples/image_search/embeddings.py
 create mode 100644 examples/image_search/vector_store.py

diff --git a/examples/image_search/app.py b/examples/image_search/app.py
new file mode 100644
index 0000000..ff2ad34
--- /dev/null
+++ b/examples/image_search/app.py
@@ -0,0 +1,125 @@
+"""
+A simple application that searches over attached images by text query,
+using multi-modal embeddings for search
+"""
+
+import os
+from uuid import uuid4
+
+import uvicorn
+from langchain.callbacks.base import AsyncCallbackHandler
+
+from aidial_sdk import DIALApp
+from aidial_sdk import HTTPException as DIALException
+from aidial_sdk.chat_completion import ChatCompletion, Choice, Request, Response
+from examples.image_search.attachment import get_image_attachments
+from examples.image_search.embeddings import ImageDialEmbeddings
+from examples.image_search.vector_store import DialImageVectorStore
+
+
+def get_env(name: str) -> str:
+    value = os.getenv(name)
+    if value is None:
+        raise ValueError(f"Please provide {name!r} environment variable")
+    return value
+
+
+DIAL_URL = get_env("DIAL_URL")
+EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "multimodalembedding@001")
+EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS") or "1408")
+
+
+class CustomCallbackHandler(AsyncCallbackHandler):
+    def __init__(self, choice: Choice):
+        self._choice = choice
+
+    async def on_llm_new_token(self, token: str, *args, **kwargs) -> None:
+        self._choice.append_content(token)
+
+
+class ImageSearchApplication(ChatCompletion):
+    async def chat_completion(
+        self, request: Request, response: Response
+    ) -> None:
+        """
+        Attach to the response the image that best matches the last message.
+
+        Raises DIALException when the query or image attachments are missing
+        or when the search returns nothing.
+        """
+        with response.create_single_choice() as choice:
+            message = request.messages[-1]
+            user_query = message.content
+
+            if not user_query:
+                raise DIALException(
+                    message="Please provide search query", status_code=400
+                )
+
+            image_attachments = get_image_attachments(request.messages)
+            if not image_attachments:
+                msg = "No attachment with DIAL Storage URL was found"
+                raise DIALException(
+                    status_code=422,
+                    message=msg,
+                    display_message=msg,
+                )
+            # Create a new local vector store to store image embeddings
+            vector_store = DialImageVectorStore(
+                collection_name=str(uuid4()),
+                embedding_function=ImageDialEmbeddings(
+                    dial_url=DIAL_URL,
+                    embeddings_model=EMBEDDINGS_MODEL,
+                    dimensions=EMBEDDINGS_DIMENSIONS,
+                ),
+            )
+            try:
+                # Show user that embeddings of images are being calculated
+                with choice.create_stage("Calculating image embeddings"):
+                    # For simplicity of example let's take only images,
+                    # that are uploaded to DIAL Storage already
+                    await vector_store.aadd_images(
+                        uris=[att.url for att in image_attachments if att.url],
+                        metadatas=[
+                            {
+                                "url": att.url,
+                                "type": att.type,
+                                "title": att.title,
+                            }
+                            for att in image_attachments
+                            if att.url
+                        ],
+                    )
+
+                # Show user that the search is being performed
+                with choice.create_stage("Searching for most relevant image"):
+                    search_result = await vector_store.asimilarity_search(
+                        query=user_query, k=1
+                    )
+            finally:
+                # Always drop the temporary collection, even when embedding
+                # or search fails, so per-request collections do not pile up
+                vector_store.delete_collection()
+
+            if not search_result:
+                msg = "No relevant image found"
+                raise DIALException(
+                    status_code=404,
+                    message=msg,
+                    display_message=msg,
+                )
+
+            top_result = search_result[0]
+            choice.add_attachment(
+                url=top_result.metadata["url"],
+                title=top_result.metadata["title"],
+                type=top_result.metadata["type"],
+            )
+
+
+app = DIALApp(DIAL_URL, propagate_auth_headers=True)
+app.add_chat_completion("image-search", ImageSearchApplication())
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, port=5000)
diff --git a/examples/image_search/attachment.py b/examples/image_search/attachment.py
new file mode 100644
index 0000000..7fb022c
--- /dev/null
+++ b/examples/image_search/attachment.py
@@ -0,0 +1,39 @@
+from typing import List, Optional
+
+from aidial_sdk.chat_completion import Message
+from aidial_sdk.chat_completion.request import Attachment
+
+DEFAULT_IMAGE_TYPES = ["image/jpeg", "image/png"]
+
+
+def get_image_attachments(
+    messages: List[Message], image_types: Optional[List[str]] = None
+) -> List[Attachment]:
+    """
+    Collect the image attachments that have a URL from all the messages.
+
+    Only attachments whose type is listed in image_types
+    (DEFAULT_IMAGE_TYPES when omitted) are returned.
+    """
+    if image_types is None:
+        image_types = DEFAULT_IMAGE_TYPES
+
+    image_attachments: List[Attachment] = []
+    for message in messages:
+        custom_content = message.custom_content
+        if custom_content is None or custom_content.attachments is None:
+            continue
+
+        # Accumulate into a separate list, so that the attachment list
+        # of the message is neither iterated while growing nor mutated
+        for attachment in custom_content.attachments:
+            if (
+                # For simplicity of example let's take only images,
+                # that are uploaded to DIAL Storage already
+                attachment.url
+                and attachment.type
+                and attachment.type in image_types
+            ):
+                image_attachments.append(attachment)
+
+    return image_attachments
diff --git a/examples/image_search/embeddings.py b/examples/image_search/embeddings.py
new file mode 100644
index 0000000..a49f6a0
--- /dev/null
+++ b/examples/image_search/embeddings.py
@@ -0,0 +1,70 @@
+from typing import List
+
+import httpx
+from langchain_core.embeddings import Embeddings
+
+
+class ImageDialEmbeddings(Embeddings):
+    """
+    Embeddings of text queries and image URLs,
+    requested from an embeddings deployment behind the DIAL URL
+    """
+
+    def __init__(
+        self,
+        dial_url: str,
+        embeddings_model: str,
+        dimensions: int,
+    ) -> None:
+        self._dial_url = dial_url
+        self._embeddings_url = (
+            f"{self._dial_url}/openai/deployments/{embeddings_model}/embeddings"
+        )
+        self._dimensions = dimensions
+        self._client = httpx.Client()
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        raise NotImplementedError(
+            "This embeddings should not be used with text documents"
+        )
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return the embedding of a text query."""
+        return self._request_embedding(
+            {"input": [text], "dimensions": self._dimensions}
+        )
+
+    def embed_image(self, uris: List[str]) -> List[List[float]]:
+        """Return one embedding per image URI, in the given order."""
+        # NOTE(review): the type is always sent as image/png,
+        # also for JPEG attachments - confirm the model ignores it
+        return [
+            self._request_embedding(
+                {
+                    "input": [],
+                    "dimensions": self._dimensions,
+                    "custom_input": [{"type": "image/png", "url": uri}],
+                }
+            )
+            for uri in uris
+        ]
+
+    def _request_embedding(self, payload: dict) -> List[float]:
+        """
+        Post the payload to the embeddings endpoint
+        and return the single embedding of the reply.
+
+        Raises httpx.HTTPStatusError on an error status and ValueError
+        when the reply does not contain exactly one embedding.
+        """
+        # Auth headers are propagated by the DIALApp
+        response = self._client.post(self._embeddings_url, json=payload)
+        # Fail on HTTP errors here, rather than on a missing "data" later
+        response.raise_for_status()
+        items = response.json().get("data")
+        # Not an assert: assertions are stripped when running with -O
+        if not items or len(items) != 1:
+            raise ValueError(
+                "Embeddings response must contain exactly one embedding"
+            )
+        return items[0]["embedding"]
diff --git a/examples/image_search/vector_store.py b/examples/image_search/vector_store.py
new file mode 100644
index 0000000..214dd52
--- /dev/null
+++ b/examples/image_search/vector_store.py
@@ -0,0 +1,20 @@
+from typing import List, Optional
+
+from langchain_community.vectorstores import Chroma
+from langchain_core.runnables.config import run_in_executor
+
+
+class DialImageVectorStore(Chroma):
+    def encode_image(self, uri: str) -> str:
+        """
+        Override of Chroma encode_image, that does not download image content
+        """
+        return uri
+
+    async def aadd_images(
+        self, uris: List[str], metadatas: Optional[List[dict]] = None
+    ):
+        """
+        Async version of add_images, that is present in Chroma
+        """
+        return await run_in_executor(None, self.add_images, uris, metadatas)

From 3bf6b0f583891a4f2ec3806ac102d06347bb0c97 Mon Sep 17 00:00:00 2001
From: Roman Romanov
Date: Mon, 19 Aug 2024 10:42:36 +0300
Subject: [PATCH 2/4] add requirements

---
 examples/image_search/app.py           | 11 +----------
 examples/image_search/requirements.txt |  5 +++++
 2 files changed, 6 insertions(+), 10 deletions(-)
 create mode 100644 examples/image_search/requirements.txt

diff --git a/examples/image_search/app.py b/examples/image_search/app.py
index ff2ad34..47d7af6 100644
--- a/examples/image_search/app.py
+++ b/examples/image_search/app.py
@@ -7,11 +7,10 @@
 from uuid import uuid4
 
 import uvicorn
-from langchain.callbacks.base import AsyncCallbackHandler
 
 from aidial_sdk import DIALApp
 from aidial_sdk import HTTPException as DIALException
-from aidial_sdk.chat_completion import ChatCompletion, Choice, Request, Response
+from aidial_sdk.chat_completion import ChatCompletion, Request, Response
 from examples.image_search.attachment import get_image_attachments
 from examples.image_search.embeddings import ImageDialEmbeddings
 from examples.image_search.vector_store import DialImageVectorStore
@@ -29,14 +28,6 @@ def get_env(name: str) -> str:
 EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS") or "1408")
 
 
-class CustomCallbackHandler(AsyncCallbackHandler):
-    def __init__(self, choice: Choice):
-        self._choice = choice
-
-    async def on_llm_new_token(self, token: str, *args, **kwargs) -> None:
-        self._choice.append_content(token)
-
-
 class ImageSearchApplication(ChatCompletion):
     async def chat_completion(
         self, request: Request, response: Response
diff --git a/examples/image_search/requirements.txt b/examples/image_search/requirements.txt
new file mode 100644
index 0000000..bcc104e
--- /dev/null
+++ b/examples/image_search/requirements.txt
@@ -0,0 +1,5 @@
+aidial-sdk>=0.10
+langchain-core==0.2.9
+langchain-community==0.2.9
+chromadb==0.5.4
+uvicorn==0.30.1
\ No newline at end of file

From a3d0fad58ace2a620224f1cf7646b1596a464216 Mon Sep 17 00:00:00 2001
From: Roman Romanov
Date: Mon, 19 Aug 2024 10:45:24 +0300
Subject: [PATCH 3/4] Fix pipeline

---
 pyproject.toml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 4183415..2d62d1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,8 @@ exclude = [
     ".pytest_cache",
     "**/__pycache__",
     "build",
-    "examples/langchain_rag"
+    "examples/langchain_rag",
+    "examples/image_search"
 ]
 
 [tool.black]

From 4f79ecf460570966b5caa752a955e0bc6097dcbc Mon Sep 17 00:00:00 2001
From: Roman Romanov
Date: Mon, 19 Aug 2024 11:06:07 +0300
Subject: [PATCH 4/4] refactor imports

---
 examples/image_search/app.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/image_search/app.py b/examples/image_search/app.py
index 47d7af6..488079d 100644
--- a/examples/image_search/app.py
+++ b/examples/image_search/app.py
@@ -7,13 +7,13 @@
 from uuid import uuid4
 
 import uvicorn
+from attachment import get_image_attachments
+from embeddings import ImageDialEmbeddings
+from vector_store import DialImageVectorStore
 
 from aidial_sdk import DIALApp
 from aidial_sdk import HTTPException as DIALException
 from aidial_sdk.chat_completion import ChatCompletion, Request, Response
-from examples.image_search.attachment import get_image_attachments
-from examples.image_search.embeddings import ImageDialEmbeddings
-from examples.image_search.vector_store import DialImageVectorStore
 
 
 def get_env(name: str) -> str: