diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 49c9cba0828351..ae98f91a889eee 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -94,7 +94,7 @@ def init(self): if not self.conversation: self.is_new_conversation = True self.conversation = Conversation( - app_id=self.app_model_config.app_id, + app_id=self.app.id, app_model_config_id=self.app_model_config.id, model_provider=self.provider_name, model_id=self.model_name, @@ -115,7 +115,7 @@ def init(self): db.session.commit() self.message = Message( - app_id=self.app_model_config.app_id, + app_id=self.app.id, model_provider=self.provider_name, model_id=self.model_name, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py index d88d3af555f3f4..d65efc96418dcc 100644 --- a/api/core/model_providers/models/llm/openai_model.py +++ b/api/core/model_providers/models/llm/openai_model.py @@ -106,13 +106,7 @@ def _run(self, messages: List[PromptMessage], raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.") prompts = self._get_prompt_from_messages(messages) - - try: - return self._client.generate([prompts], stop, callbacks) - finally: - thread_context = api_requestor._thread_context - if hasattr(thread_context, "session") and thread_context.session: - thread_context.session.close() + return self._client.generate([prompts], stop, callbacks) def get_num_tokens(self, messages: List[PromptMessage]) -> int: """ diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 252c85e93ffc65..d8ffd02ed4e211 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -155,7 +155,7 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, generate_worker_thread.start() # wait for 10 minutes to close the thread - cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id) + cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) return cls.compact_response(pubsub, streaming) @@ -210,25 +210,26 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App db.session.commit() @classmethod - def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread: + def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread: # wait for 10 minutes to close the thread timeout = 600 def close_pubsub(): - sleep_iterations = 0 - while sleep_iterations < timeout and worker_thread.is_alive(): - if sleep_iterations > 0 and sleep_iterations % 10 == 0: - PubHandler.ping(user, generate_task_id) + with flask_app.app_context(): + sleep_iterations = 0 + while sleep_iterations < timeout and worker_thread.is_alive(): + if sleep_iterations > 0 and sleep_iterations % 10 == 0: + PubHandler.ping(user, generate_task_id) - time.sleep(1) - sleep_iterations += 1 + time.sleep(1) + sleep_iterations += 1 - if worker_thread.is_alive(): - PubHandler.stop(user, generate_task_id) - try: - pubsub.close() - except: - pass + if worker_thread.is_alive(): + PubHandler.stop(user, generate_task_id) + try: + pubsub.close() + except: + pass countdown_thread = threading.Thread(target=close_pubsub) countdown_thread.start() @@ -288,7 +289,7 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], generate_worker_thread.start() - cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id) + cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) return cls.compact_response(pubsub, streaming) @@ -313,15 +314,14 @@ def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError) as e: - db.session.rollback() PubHandler.pub_error(user, generate_task_id, e) except LLMAuthorizationError: - db.session.rollback() PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided')) except Exception as e: - db.session.rollback() logging.exception("Unknown Error in completion") PubHandler.pub_error(user, generate_task_id, e) + finally: + db.session.commit() @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):