From 62bfbada626dd0a7adce31fdac21b832ab799312 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 22 Apr 2024 14:15:38 +0800 Subject: [PATCH] - fix type hint error with streamlit >= 1.33.0 (#8) - add ChatBox.change_chat_name to rename a chat conversation - maintain a context bound to chat conversation, it is like to be a sub session_state for every chat, context will get changed when you switch chat names. - user can save chat bound values by `ChatBox.context['x'] = 1` - values of widgets specified with a key can be saved to chat context with `ChatBox.context_from_session` and restored to st.session_state by `ChatBox.context_to_session` --- README.md | 50 +++++++++++++++++++-- example.py | 18 +++++--- setup.py | 2 +- streamlit_chatbox/elements.py | 11 ++--- streamlit_chatbox/messages.py | 85 +++++++++++++++++++++++++++++------ 5 files changed, 138 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index cca7f52..842adfa 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ It's basiclly a wrapper of streamlit officeial elements including the chat elemn - support streaming output. - support markdown/image/video/audio messages, and all streamlit elements could be supported by customized `OutputElement`. - output multiple messages at once, and make them collapsable. +- maintain session state context bound to chat conversation - export & import chat histories This make it easy to chat with langchain LLMs in streamlit. @@ -46,13 +47,21 @@ import simplejson as json llm = FakeLLM() chat_box = ChatBox() +chat_box.use_chat_name("chat1") # add a chat conversatoin + +def on_chat_change(): + chat_box.use_chat_name(st.session_state["chat_name"]) + chat_box.context_to_session() # restore widget values to st.session_state when chat name changed with st.sidebar: st.subheader('start to chat using streamlit') - streaming = st.checkbox('streaming', True) - in_expander = st.checkbox('show messages in expander', True) - show_history = st.checkbox('show history', False) + chat_name = st.selectbox("Chat Session:", ["default", "chat1"], key="chat_name", on_change=on_chat_change) + chat_box.use_chat_name(chat_name) + streaming = st.checkbox('streaming', key="streaming") + in_expander = st.checkbox('show messages in expander', key="in_expander") + show_history = st.checkbox('show session state', key="show_history") + chat_box.context_from_session(exclude=["chat_name"]) # save widget values to chat context st.divider() @@ -71,6 +80,22 @@ with st.sidebar: chat_box.init_session() chat_box.output_messages() +def on_feedback( + feedback, + chat_history_id: str = "", + history_index: int = -1, +): + reason = feedback["text"] + score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) # convert emoji to integer + # do something + st.session_state["need_rerun"] = True + + +feedback_kwargs = { + "feedback_type": "thumbs", + "optional_text_label": "wellcome to feedback", +} + if query := st.chat_input('input your question here'): chat_box.user_say(query) if streaming: @@ -91,6 +116,11 @@ if query := st.chat_input('input your question here'): # update the element without focus chat_box.update_msg(text, element_index=0, streaming=False, state="complete") chat_box.update_msg("\n\n".join(docs), element_index=1, streaming=False, state="complete") + chat_history_id = "some id" + chat_box.show_feedback(**feedback_kwargs, + key=chat_history_id, + on_submit=on_feedback, + kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1}) else: text, docs = llm.chat(query) chat_box.ai_say( @@ -156,7 +186,8 @@ if btns.button("clear history"): if show_history: - st.write(chat_box.history) + st.write(st.session_state) + ``` ## Todos @@ -189,3 +220,14 @@ if show_history: - [x] import json - [x] support output of langchain' Agent. +- [x] conext bound to chat + +# changelog + +## v1.1.12 +- fix type hint error with streamlit >= 1.33.0 (#8) +- add ChatBox.change_chat_name to rename a chat conversation +- maintain a context bound to chat conversation, it is like to be a sub session_state for every chat, context will get changed when you switch chat names. + - user can save chat bound values by `ChatBox.context['x'] = 1` + - values of widgets specified with a key can be saved to chat context with `ChatBox.context_from_session` and restored to st.session_state by `ChatBox.context_to_session` + diff --git a/example.py b/example.py index 6a644ec..d1034c6 100644 --- a/example.py +++ b/example.py @@ -6,13 +6,21 @@ llm = FakeLLM() chat_box = ChatBox() +chat_box.use_chat_name("chat1") # add a chat conversatoin + +def on_chat_change(): + chat_box.use_chat_name(st.session_state["chat_name"]) + chat_box.context_to_session() # restore widget values to st.session_state when chat name changed with st.sidebar: st.subheader('start to chat using streamlit') - streaming = st.checkbox('streaming', True) - in_expander = st.checkbox('show messages in expander', True) - show_history = st.checkbox('show history', False) + chat_name = st.selectbox("Chat Session:", ["default", "chat1"], key="chat_name", on_change=on_chat_change) + chat_box.use_chat_name(chat_name) + streaming = st.checkbox('streaming', key="streaming") + in_expander = st.checkbox('show messages in expander', key="in_expander") + show_history = st.checkbox('show session state', key="show_history") + chat_box.context_from_session(exclude=["chat_name"]) # save widget values to chat context st.divider() @@ -44,7 +52,7 @@ def on_feedback( feedback_kwargs = { "feedback_type": "thumbs", - "optional_text_label": "欢迎反馈您打分的理由", + "optional_text_label": "wellcome to feedback", } if query := st.chat_input('input your question here'): @@ -137,4 +145,4 @@ def on_feedback( if show_history: - st.write(chat_box.history) + st.write(st.session_state) diff --git a/setup.py b/setup.py index 33c387e..b049a3a 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def readme(): setuptools.setup( name='streamlit-chatbox', - version='1.1.11', + version='1.1.12', author='liunux', author_email='liunux@qq.com', description='A chat box and some helpful tools used to build chatbot app with streamlit', diff --git a/streamlit_chatbox/elements.py b/streamlit_chatbox/elements.py index 213ac81..a415823 100644 --- a/streamlit_chatbox/elements.py +++ b/streamlit_chatbox/elements.py @@ -1,5 +1,6 @@ from typing import * import streamlit as st +from streamlit.delta_generator import DeltaGenerator # from pydantic import BaseModel, Field @@ -40,7 +41,7 @@ def _set_default_kwargs(self) -> None: for k, v in default.items(): self._kwargs.setdefault(k, v) - def __call__(self, render_to: st._DeltaGenerator=None) -> st._DeltaGenerator: + def __call__(self, render_to: DeltaGenerator=None) -> DeltaGenerator: # assert self._dg is None, "Every element can be rendered once only." render_to = render_to or st self._place_holder = render_to.empty() @@ -51,11 +52,11 @@ def __call__(self, render_to: st._DeltaGenerator=None) -> st._DeltaGenerator: return self._dg @property - def dg(self) -> st._DeltaGenerator: + def dg(self) -> DeltaGenerator: return self._dg @property - def place_holder(self) -> st._DeltaGenerator: + def place_holder(self) -> DeltaGenerator: return self._place_holder @property @@ -132,7 +133,7 @@ def from_dict(cls, d: Dict) -> "OutputElement": return factory_cls(**kwargs) - def __call__(self, render_to: st._DeltaGenerator=None, direct: bool=False) -> st._DeltaGenerator: + def __call__(self, render_to: DeltaGenerator=None, direct: bool=False) -> DeltaGenerator: if render_to is None: if self._place_holder is None: self._place_holder = st.empty() @@ -158,7 +159,7 @@ def update_element( title: str = None, expanded: bool = None, state: bool = None, - ) -> st._DeltaGenerator: + ) -> DeltaGenerator: assert self.place_holder is not None, f"You must render the element {self} before setting new element." attrs = {} if title is not None: diff --git a/streamlit_chatbox/messages.py b/streamlit_chatbox/messages.py index 7f6d150..8a98a92 100644 --- a/streamlit_chatbox/messages.py +++ b/streamlit_chatbox/messages.py @@ -22,6 +22,23 @@ } +class AttrDict(dict): + def __getattr__(self, key: str) -> Any: + try: + return self[key] + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key: str, value: Any) -> None: + self[key] = value + + def __delattr__(self, key: str) -> None: + try: + del self[key] + except KeyError: + raise AttributeError(key) + + class ChatBox: def __init__( self, @@ -54,25 +71,33 @@ def init_session(self, clear: bool =False): time.sleep(0.1) self.reset_history(self._chat_name) - def reset_history(self, name=None): + def reset_history(self, name: str = None): if not self.chat_inited: st.session_state[self._session_key] = {} - name = name or self._chat_name - st.session_state[self._session_key][name] = [] + name = name or self.cur_chat_name + st.session_state[self._session_key][name] = {"history": [], "context": AttrDict()} if self._greetings: - st.session_state[self._session_key][name] = [{ + st.session_state[self._session_key][name]["history"] = [{ "role": "assistant", "elements": self._greetings, "metadata": {}, }] - def use_chat_name(self, name: str ="default") -> None: + def use_chat_name(self, name: str = "default") -> None: self.init_session() self._chat_name = name if name not in st.session_state[self._session_key]: self.reset_history(name) + def change_chat_name(self, new_name: str, origin_name: str = None) -> bool: + self.init_session() + origin_name = origin_name or self.cur_chat_name + if (origin_name in st.session_state[self._session_key] + and new_name not in st.session_state[self._session_key]): + st.session_state[self._session_key][new_name] = st.session_state[self._session_key].pop(origin_name) + self._chat_name = new_name + def del_chat_name(self, name: str): self.init_session() if name in st.session_state[self._session_key]: @@ -88,15 +113,48 @@ def get_chat_names(self): def cur_chat_name(self): return self._chat_name + @property + def context(self) -> AttrDict: + self.init_session() + return st.session_state[self._session_key].get(self._chat_name, {}).get("context", AttrDict()) + @property def history(self) -> List: self.init_session() - return st.session_state[self._session_key].get(self._chat_name, []) + return st.session_state[self._session_key].get(self._chat_name, {}).get("history", []) def other_history(self, chat_name: str, default: List=[]) -> Optional[List]: self.init_session() chat_name = chat_name or self.cur_chat_name - return st.session_state[self._session_key].get(chat_name, default) + return st.session_state[self._session_key].get(chat_name, {}).get("history", default) + + def other_context(self, chat_name: str, default: AttrDict=AttrDict()) -> AttrDict: + self.init_session() + chat_name = chat_name or self.cur_chat_name + return st.session_state[self._session_key].get(chat_name, {}).get("context", default) + + def context_to_session(self, chat_name: str=None, include: List[str]=[], exclude: List[str]=[]) -> None: + ''' + copy context to st.session_state. + copy named variables only if `kw` specified + this can be usefull to restore session_state when you switch between chat conversations + ''' + for k, v in self.other_context(chat_name).items(): + if (not include or k in include + and k not in exclude): + st.session_state[k] = v + + def context_from_session(self, chat_name: str=None, include: List[str]=[], exclude: List[str]=[]) -> None: + ''' + copy context from st.session_state. + copy named variables only if `kw` specified + this can be usefull to save session_state when you switch between chat conversations + ''' + for k, v in st.session_state.items(): + if ((not include or k in include) + and k not in exclude + and k != self._session_key): + self.other_context(chat_name)[k] = v def filter_history( self, @@ -194,7 +252,7 @@ def to_dict( self, ) -> Dict: ''' - export current state to dict + export current state including messages and context to dict ''' self.init_session() @@ -208,7 +266,7 @@ def p(val): else: return val - histories = {x: p(self.other_history(x)) for x in self.get_chat_names()} + histories = {x: p({"history": self.other_history(x), "context": self.other_context(x)}) for x in self.get_chat_names()} return { "cur_chat_name": self.cur_chat_name, "session_key": self._session_key, @@ -246,11 +304,12 @@ def from_dict( self.reset_history(name) for h in history: msg = { - "role": h["role"], - "elements": [OutputElement.from_dict(y) for y in h["elements"]], - "metadata": h["metadata"], + "role": h["history"]["role"], + "elements": [OutputElement.from_dict(y) for y in h["history"]["elements"]], + "metadata": h["history"]["metadata"], } self.other_history(name).append(msg) + self.other_context(name).update(h["context"]) self.use_chat_name(data["cur_chat_name"]) return self @@ -348,7 +407,7 @@ def update_msg( expanded: bool = None, state: bool = None, metadata: Dict = {}, - ) -> st._DeltaGenerator: + ) -> DeltaGenerator: self.init_session() if not self.history or not self.history[history_index]["elements"]: return