From 5da24596b6b2f409d3f6fc0659db721679ff8270 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Wed, 1 Nov 2023 23:58:30 +0800 Subject: [PATCH] support streamlit-feedback --- README.md | 1 + example.py | 21 +++++++++++++++++++ setup.py | 3 ++- streamlit_chatbox/__init__.py | 2 +- streamlit_chatbox/messages.py | 39 +++++++++++++++++++++++++++++++++-- 5 files changed, 62 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index cf9dbe3..cca7f52 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,7 @@ if show_history: - [x] streaming output message - [x] show message in expander - [ ] style the output message + - [x] feedback by user - [x] export & import chat history - [x] export to markdown diff --git a/example.py b/example.py index 7ffe8a4..6a644ec 100644 --- a/example.py +++ b/example.py @@ -31,6 +31,22 @@ chat_box.init_session() chat_box.output_messages() +def on_feedback( + feedback, + chat_history_id: str = "", + history_index: int = -1, +): + reason = feedback["text"] + score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) # convert emoji to integer + # do something + st.session_state["need_rerun"] = True + + +feedback_kwargs = { + "feedback_type": "thumbs", + "optional_text_label": "欢迎反馈您打分的理由", +} + if query := st.chat_input('input your question here'): chat_box.user_say(query) if streaming: @@ -51,6 +67,11 @@ # update the element without focus chat_box.update_msg(text, element_index=0, streaming=False, state="complete") chat_box.update_msg("\n\n".join(docs), element_index=1, streaming=False, state="complete") + chat_history_id = "some id" + chat_box.show_feedback(**feedback_kwargs, + key=chat_history_id, + on_submit=on_feedback, + kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1}) else: text, docs = llm.chat(query) chat_box.ai_say( diff --git a/setup.py b/setup.py index d47ee8e..33c387e 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def readme(): setuptools.setup( name='streamlit-chatbox', - version='1.1.10', + version='1.1.11', author='liunux', author_email='liunux@qq.com', description='A chat box and some helpful tools used to build chatbot app with streamlit', @@ -23,5 +23,6 @@ def readme(): install_requires=[ 'streamlit>=1.26.0', 'simplejson', + 'streamlit-feedback', ] ) diff --git a/streamlit_chatbox/__init__.py b/streamlit_chatbox/__init__.py index 92d6035..bd9f124 100644 --- a/streamlit_chatbox/__init__.py +++ b/streamlit_chatbox/__init__.py @@ -4,7 +4,7 @@ from .messages import * -__version__ = "1.1.9" +__version__ = "1.1.11" __all__ = [ diff --git a/streamlit_chatbox/messages.py b/streamlit_chatbox/messages.py index aee1fcf..bdc8503 100644 --- a/streamlit_chatbox/messages.py +++ b/streamlit_chatbox/messages.py @@ -1,4 +1,5 @@ from streamlit_chatbox.elements import * +from streamlit_feedback import streamlit_feedback import time import inspect import simplejson as json @@ -15,6 +16,12 @@ # - if in expander: element is rendered in a st.empty in the st.status +POSSIBLE_SCORES = { + "thumbs": ["👍", "👎"], + "faces": ["😞", "🙁", "😐", "🙂", "😀"], +} + + class ChatBox: def __init__( self, @@ -290,13 +297,32 @@ def ai_say( for element in elements: element(render_to=container) - self.history.append({"role": "assistant", "elements": elements, "metadata": metadata}) + self.history.append({"role": "assistant", "elements": elements, "metadata": metadata.copy()}) return elements + def show_feedback(self, history_index=-1, **kwargs): + ''' + render feedback component + ''' + with self._chat_containers[history_index]: + self.history[history_index]["metadata"]["feedback_kwargs"] = kwargs + return streamlit_feedback(**kwargs) + + def set_feedback(self, feedback: Dict, history_index=-1) -> int: + ''' + set the feedback state for msg with a index of history_index + return the index of streamlit_feedback's emoji score + ''' + self.history[history_index]["metadata"]["feedback"] = feedback + score = feedback.get("score") + for v in POSSIBLE_SCORES.values(): + if score in v: + return v.index(score) + def output_messages(self): self.init_session() self._chat_containers = [] - for msg in self.history: + for i, msg in enumerate(self.history): avatar = self._user_avatar if msg["role"] == "user" else self._assistant_avatar chat_ele = st.chat_message(msg["role"], avatar=avatar) container = chat_ele.container() @@ -304,6 +330,12 @@ def output_messages(self): for element in msg["elements"]: element(render_to=container) + feedback_kwargs = msg["metadata"].get("feedback_kwargs", {}) + if feedback_kwargs: + if feedback := msg["metadata"].get("feedback"): + feedback_kwargs["disable_with_score"] = feedback["score"] + self.show_feedback(history_index=i, **feedback_kwargs) + def update_msg( self, element: Union["OutputElement", str] = None, @@ -314,6 +346,7 @@ def update_msg( title: str = None, expanded: bool = None, state: bool = None, + metadata: Dict = {}, ) -> st._DeltaGenerator: self.init_session() if not self.history or not self.history[history_index]["elements"]: @@ -331,6 +364,8 @@ def update_msg( element.status_from(old_element) self.history[history_index]["elements"][element_index] = element + self.history[history_index]["metadata"].update(metadata) + dg = old_element.update_element( element, title=title,