Skip to content

Commit

Permalink
add deduct quota for llm node
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Mar 12, 2024
1 parent 4d7caa3 commit 5fe0d50
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
Expand All @@ -21,6 +22,7 @@
from core.workflow.nodes.llm.entities import LLMNodeData
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus


Expand Down Expand Up @@ -144,10 +146,15 @@ def _invoke_llm(self, node_data: LLMNodeData,
)

# handle invoke result
return self._handle_invoke_result(
text, usage = self._handle_invoke_result(
invoke_result=invoke_result
)

# deduct quota
self._deduct_llm_quota(model_instance=model_instance, usage=usage)

return text, usage

def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
Expand Down Expand Up @@ -373,6 +380,53 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData,

return prompt_messages, stop

def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """
    Charge the tenant's system-provider quota for one LLM invocation.

    Nothing is deducted when the provider in use is not a system provider,
    when no quota configuration matches the current quota type, when that
    configuration has no quota unit, or when its quota limit is -1.

    :param model_instance: model instance that served the invocation
    :param usage: usage reported for the invocation
    :return:
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_config = configuration.system_configuration
    active_quota_type = system_config.current_quota_type

    # Only the first configuration of the current quota type is considered.
    active_quota = next(
        (
            candidate
            for candidate in system_config.quota_configurations
            if candidate.quota_type == active_quota_type
        ),
        None,
    )
    if active_quota is None:
        return

    unit = active_quota.quota_unit
    # A limit of -1 skips deduction (presumably "unlimited" - confirm).
    if active_quota.quota_limit == -1 or not unit:
        return

    if unit == QuotaUnit.TOKENS:
        amount = usage.total_tokens
    elif unit == QuotaUnit.CREDITS:
        # Models whose name contains 'gpt-4' cost 20 credits, others cost 1.
        amount = 20 if 'gpt-4' in model_instance.model else 1
    else:
        # Any other unit is charged one unit per invocation.
        amount = 1

    if amount is None:
        return

    # The filter on quota_limit > quota_used means a provider row that has
    # already reached its limit is left unchanged.
    provider_rows = db.session.query(Provider).filter(
        Provider.tenant_id == self.tenant_id,
        Provider.provider_name == model_instance.provider,
        Provider.provider_type == ProviderType.SYSTEM.value,
        Provider.quota_type == active_quota_type.value,
        Provider.quota_limit > Provider.quota_used
    )
    provider_rows.update({'quota_used': Provider.quota_used + amount})
    db.session.commit()

@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Expand Down

0 comments on commit 5fe0d50

Please sign in to comment.