Skip to content

Commit

Permalink
Merge branch 'feat/model-runtime' into deploy/dev
Browse files: browse the repository at this point in the history
  • Loading branch information
takatost committed Dec 28, 2023
2 parents 35e4f4b + 3087fc5 commit e030458
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 11 deletions.
16 changes: 11 additions & 5 deletions api/core/app_runner/basic_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,
if app_orchestration_config.dataset:
context = self.retrieve_dataset_context(
tenant_id=app_record.tenant_id,
app_mode=app_record.mode,
app_record=app_record,
queue_manager=queue_manager,
model_config=app_orchestration_config.model_config,
show_retrieve_source=app_orchestration_config.show_retrieve_source,
Expand Down Expand Up @@ -283,7 +283,7 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
)

def retrieve_dataset_context(self, tenant_id: str,
app_mode: str,
app_record: App,
queue_manager: ApplicationQueueManager,
model_config: ModelConfigEntity,
dataset_config: DatasetEntity,
Expand All @@ -297,7 +297,7 @@ def retrieve_dataset_context(self, tenant_id: str,
"""
Retrieve dataset context
:param tenant_id: tenant id
:param app_mode: app mode
:param app_record: app record
:param queue_manager: queue manager
:param model_config: model config
:param dataset_config: dataset config
Expand All @@ -310,9 +310,15 @@ def retrieve_dataset_context(self, tenant_id: str,
:param memory: memory
:return:
"""
hit_callback = DatasetIndexToolCallbackHandler(queue_manager, message.id, user_id)
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
user_id,
invoke_from
)

if (app_mode == AppMode.COMPLETION.value and dataset_config
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")

Expand Down
26 changes: 24 additions & 2 deletions api/core/callback_handler/index_tool_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,42 @@
from langchain.schema import Document

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import DocumentSegment
from models.dataset import DocumentSegment, DatasetQuery
from models.model import DatasetRetrieverResource


class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""

def __init__(self, queue_manager: ApplicationQueueManager,
app_id: str,
message_id: str,
user_id: str) -> None:
user_id: str,
invoke_from: InvokeFrom) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
self._user_id = user_id
self._invoke_from = invoke_from

def on_query(self, query: str, dataset_id: str) -> None:
    """
    Record a query made against a dataset.

    Builds a DatasetQuery row with source 'app' from the query text, the
    dataset id and the app id / user id held by this handler, adds it to
    the database session and commits.

    :param query: query text
    :param dataset_id: id of the dataset being queried
    :return:
    """
    # explore and debugger invocations are recorded as 'account',
    # every other invocation source as 'end_user'
    if self._invoke_from in (InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER):
        creator_role = 'account'
    else:
        creator_role = 'end_user'

    query_record = DatasetQuery(
        dataset_id=dataset_id,
        content=query,
        source='app',
        source_app_id=self._app_id,
        created_by_role=creator_role,
        created_by=self._user_id,
    )

    db.session.add(query_record)
    db.session.commit()

def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end."""
Expand Down
4 changes: 3 additions & 1 deletion api/core/features/agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ def to_dataset_retriever_tool(self, tool_config: dict,

hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id
user_id=self.user_id,
invoke_from=invoke_from
)

# get dataset from dataset id
Expand Down
10 changes: 8 additions & 2 deletions api/core/tool/dataset_multi_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def _run(self, query: str) -> str:
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'all_documents': all_documents
'all_documents': all_documents,
'hit_callbacks': self.hit_callbacks
})
threads.append(retrieval_thread)
retrieval_thread.start()
Expand Down Expand Up @@ -154,7 +155,8 @@ def _run(self, query: str) -> str:
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()

def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List,
hit_callbacks: List[DatasetIndexToolCallbackHandler]):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Expand All @@ -163,6 +165,10 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document

if not dataset:
return []

for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)

# get the retrieval model; if none is set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

Expand Down
5 changes: 4 additions & 1 deletion api/core/tool/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ def _run(self, query: str) -> str:

if not dataset:
return ''

for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)

# get the retrieval model; if none is set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
Expand Down

0 comments on commit e030458

Please sign in to comment.