Skip to content

Commit

Permalink
fix: count down thread in completion db not commit (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost authored Oct 2, 2023
1 parent 86a9dea commit 41d4c5b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
4 changes: 2 additions & 2 deletions api/core/conversation_message_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def init(self):
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
app_model_config_id=self.app_model_config.id,
model_provider=self.provider_name,
model_id=self.model_name,
Expand All @@ -115,7 +115,7 @@ def init(self):
db.session.commit()

self.message = Message(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
Expand Down
8 changes: 1 addition & 7 deletions api/core/model_providers/models/llm/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,7 @@ def _run(self, messages: List[PromptMessage],
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

prompts = self._get_prompt_from_messages(messages)

try:
return self._client.generate([prompts], stop, callbacks)
finally:
thread_context = api_requestor._thread_context
if hasattr(thread_context, "session") and thread_context.session:
thread_context.session.close()
return self._client.generate([prompts], stop, callbacks)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
Expand Down
36 changes: 18 additions & 18 deletions api/services/completion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
generate_worker_thread.start()

# wait for 10 minutes to close the thread
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)

return cls.compact_response(pubsub, streaming)

Expand Down Expand Up @@ -210,25 +210,26 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App
db.session.commit()

@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600

def close_pubsub():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
with flask_app.app_context():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)

time.sleep(1)
sleep_iterations += 1
time.sleep(1)
sleep_iterations += 1

if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass

countdown_thread = threading.Thread(target=close_pubsub)
countdown_thread.start()
Expand Down Expand Up @@ -288,7 +289,7 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],

generate_worker_thread.start()

cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)

return cls.compact_response(pubsub, streaming)

Expand All @@ -313,15 +314,14 @@ def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str,
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()

@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
Expand Down

0 comments on commit 41d4c5b

Please sign in to comment.