diff --git a/examples/image_search/app.py b/examples/image_search/app.py
new file mode 100644
index 0000000..488079d
--- /dev/null
+++ b/examples/image_search/app.py
@@ -0,0 +1,116 @@
+"""
+A simple application that searches over attached images by text query,
+using multi-modal embeddings for search
+"""
+
+import os
+from uuid import uuid4
+
+import uvicorn
+from attachment import get_image_attachments
+from embeddings import ImageDialEmbeddings
+from vector_store import DialImageVectorStore
+
+from aidial_sdk import DIALApp
+from aidial_sdk import HTTPException as DIALException
+from aidial_sdk.chat_completion import ChatCompletion, Request, Response
+
+
+def get_env(name: str) -> str:
+    """Return the value of the required environment variable `name`.
+
+    Raises ValueError when the variable is not set.
+    """
+    value = os.getenv(name)
+    if value is None:
+        raise ValueError(f"Please provide {name!r} environment variable")
+    return value
+
+
+DIAL_URL = get_env("DIAL_URL")
+EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "multimodalembedding@001")
+EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS") or "1408")
+
+
+class ImageSearchApplication(ChatCompletion):
+    """Replies with the attached image that best matches the text query."""
+
+    async def chat_completion(
+        self, request: Request, response: Response
+    ) -> None:
+        with response.create_single_choice() as choice:
+            message = request.messages[-1]
+            user_query = message.content
+
+            if not user_query:
+                raise DIALException(
+                    message="Please provide search query", status_code=400
+                )
+
+            image_attachments = get_image_attachments(request.messages)
+            if not image_attachments:
+                msg = "No attachment with DIAL Storage URL was found"
+                raise DIALException(
+                    status_code=422,
+                    message=msg,
+                    display_message=msg,
+                )
+            # Create a new local vector store to store image embeddings
+            vector_store = DialImageVectorStore(
+                collection_name=str(uuid4()),
+                embedding_function=ImageDialEmbeddings(
+                    dial_url=DIAL_URL,
+                    embeddings_model=EMBEDDINGS_MODEL,
+                    dimensions=EMBEDDINGS_DIMENSIONS,
+                ),
+            )
+            # The collection is created per request, so it has to be
+            # removed on every exit path, including the error ones
+            try:
+                # Show user that embeddings of images are being calculated
+                with choice.create_stage("Calculating image embeddings"):
+                    # For simplicity of example let's take only images,
+                    # that are uploaded to DIAL Storage already
+                    await vector_store.aadd_images(
+                        uris=[att.url for att in image_attachments if att.url],
+                        metadatas=[
+                            {
+                                "url": att.url,
+                                "type": att.type,
+                                "title": att.title,
+                            }
+                            for att in image_attachments
+                            if att.url
+                        ],
+                    )
+
+                # Show user that the search is being performed
+                with choice.create_stage("Searching for most relevant image"):
+                    search_result = await vector_store.asimilarity_search(
+                        query=user_query, k=1
+                    )
+            finally:
+                vector_store.delete_collection()
+
+            if not search_result:
+                msg = "No relevant image found"
+                raise DIALException(
+                    status_code=404,
+                    message=msg,
+                    display_message=msg,
+                )
+
+            top_result = search_result[0]
+            choice.add_attachment(
+                url=top_result.metadata["url"],
+                title=top_result.metadata["title"],
+                type=top_result.metadata["type"],
+            )
+
+
+app = DIALApp(DIAL_URL, propagate_auth_headers=True)
+app.add_chat_completion("image-search", ImageSearchApplication())
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, port=5000)
diff --git a/examples/image_search/attachment.py b/examples/image_search/attachment.py
new file mode 100644
index 0000000..7fb022c
--- /dev/null
+++ b/examples/image_search/attachment.py
@@ -0,0 +1,38 @@
+from typing import List, Optional
+
+from aidial_sdk.chat_completion import Message
+from aidial_sdk.chat_completion.request import Attachment
+
+DEFAULT_IMAGE_TYPES = ["image/jpeg", "image/png"]
+
+
+def get_image_attachments(
+    messages: List[Message], image_types: Optional[List[str]] = None
+) -> List[Attachment]:
+    """Collect the image attachments that have a URL from all messages.
+
+    An attachment is kept only when it has a URL and its type is one of
+    image_types (DEFAULT_IMAGE_TYPES when image_types is None).
+    The attachment lists of the messages are not modified.
+    """
+    if image_types is None:
+        image_types = DEFAULT_IMAGE_TYPES
+
+    image_attachments: List[Attachment] = []
+    for message in messages:
+        custom_content = message.custom_content
+        if custom_content is None or custom_content.attachments is None:
+            continue
+        # Keep the result separate from the list being iterated: appending
+        # to the message's own list would make the loop never terminate
+        for attachment in custom_content.attachments:
+            if (
+                # For simplicity of example let's take only images,
+                # that are uploaded to DIAL Storage already
+                attachment.url
+                and attachment.type
+                and attachment.type in image_types
+            ):
+                image_attachments.append(attachment)
+
+    return image_attachments
diff --git a/examples/image_search/embeddings.py b/examples/image_search/embeddings.py
new file mode 100644
index 0000000..a49f6a0
--- /dev/null
+++ b/examples/image_search/embeddings.py
@@ -0,0 +1,72 @@
+from typing import List
+
+import httpx
+from langchain_core.embeddings import Embeddings
+
+
+class ImageDialEmbeddings(Embeddings):
+    """Embeddings of text queries and image URLs via a DIAL deployment."""
+
+    def __init__(
+        self,
+        dial_url: str,
+        embeddings_model: str,
+        dimensions: int,
+    ) -> None:
+        self._dial_url = dial_url
+        self._embeddings_url = (
+            f"{self._dial_url}/openai/deployments/{embeddings_model}/embeddings"
+        )
+        self._dimensions = dimensions
+        self._client = httpx.Client()
+
+    def _request_embedding(self, payload: dict) -> List[float]:
+        """Post payload to the embeddings endpoint, return its only vector.
+
+        Raises httpx.HTTPStatusError on an error status and ValueError
+        when the response does not contain exactly one embedding.
+        """
+        # Auth headers are propagated by the DIALApp
+        response = self._client.post(self._embeddings_url, json=payload)
+        # Fail on the HTTP error itself rather than on its response body
+        response.raise_for_status()
+        items = response.json().get("data")
+        # Explicit check: an assert would be stripped when running with -O
+        if not items or len(items) != 1:
+            raise ValueError(
+                "Expected exactly one embedding in the response, "
+                f"got {len(items) if items else 0}"
+            )
+        return items[0].get("embedding")
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        raise NotImplementedError(
+            "This embeddings should not be used with text documents"
+        )
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return the embedding of a text query."""
+        return self._request_embedding(
+            {"input": [text], "dimensions": self._dimensions}
+        )
+
+    def embed_image(self, uris: List[str]) -> List[List[float]]:
+        """Return one embedding per image URI, in the order of uris."""
+        return [
+            self._request_embedding(
+                {
+                    "input": [],
+                    "dimensions": self._dimensions,
+                    "custom_input": [
+                        {
+                            # NOTE(review): the type is fixed to PNG even
+                            # though JPEG attachments are accepted upstream;
+                            # confirm the deployment tolerates it
+                            "type": "image/png",
+                            "url": uri,
+                        }
+                    ],
+                }
+            )
+            for uri in uris
+        ]
diff --git a/examples/image_search/requirements.txt b/examples/image_search/requirements.txt
new file mode 100644
index 0000000..bcc104e
--- /dev/null
+++ b/examples/image_search/requirements.txt
@@ -0,0 +1,5 @@
+aidial-sdk>=0.10
+langchain-core==0.2.9
+langchain-community==0.2.9
+chromadb==0.5.4
+uvicorn==0.30.1
\ No newline at end of file
diff --git a/examples/image_search/vector_store.py b/examples/image_search/vector_store.py
new file mode 100644
index 0000000..214dd52
--- /dev/null
+++ b/examples/image_search/vector_store.py
@@ -0,0 +1,20 @@
+from typing import List, Optional
+
+from langchain_community.vectorstores import Chroma
+from langchain_core.runnables.config import run_in_executor
+
+
+class DialImageVectorStore(Chroma):
+    def encode_image(self, uri: str) -> str:
+        """
+        Override of Chroma encode_image method, that does not download image content
+        """
+        return uri
+
+    async def aadd_images(
+        self, uris: List[str], metadatas: Optional[List[dict]] = None
+    ):
+        """
+        Async version of add_images, that is present in Chroma
+        """
+        return await run_in_executor(None, self.add_images, uris, metadatas)
diff --git a/pyproject.toml b/pyproject.toml
index 4183415..2d62d1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,8 @@ exclude = [
     ".pytest_cache",
     "**/__pycache__",
     "build",
-    "examples/langchain_rag"
+    "examples/langchain_rag",
+    "examples/image_search"
 ]
 
 [tool.black]