Skip to content

Commit

Permalink
- fix type hint error with streamlit >= 1.33.0 (#8)
Browse files Browse the repository at this point in the history
- add ChatBox.change_chat_name to rename a chat conversation
- maintain a context bound to chat conversation, it is like to be a sub session_state for every chat, context will get changed when you switch chat names.
      - user can save chat bound values by `ChatBox.context['x'] = 1`
      - values of widgets specified with a key can be saved to chat context with `ChatBox.context_from_session` and restored to st.session_state by `ChatBox.context_to_session`
  • Loading branch information
liunux4odoo committed Apr 22, 2024
1 parent 6179c41 commit 62bfbad
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 28 deletions.
50 changes: 46 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ It's basiclly a wrapper of streamlit officeial elements including the chat elemn
- support streaming output.
- support markdown/image/video/audio messages, and all streamlit elements could be supported by customized `OutputElement`.
- output multiple messages at once, and make them collapsable.
- maintain session state context bound to chat conversation
- export & import chat histories

This make it easy to chat with langchain LLMs in streamlit.
Expand All @@ -46,13 +47,21 @@ import simplejson as json

llm = FakeLLM()
chat_box = ChatBox()
chat_box.use_chat_name("chat1") # add a chat conversatoin

def on_chat_change():
chat_box.use_chat_name(st.session_state["chat_name"])
chat_box.context_to_session() # restore widget values to st.session_state when chat name changed


with st.sidebar:
st.subheader('start to chat using streamlit')
streaming = st.checkbox('streaming', True)
in_expander = st.checkbox('show messages in expander', True)
show_history = st.checkbox('show history', False)
chat_name = st.selectbox("Chat Session:", ["default", "chat1"], key="chat_name", on_change=on_chat_change)
chat_box.use_chat_name(chat_name)
streaming = st.checkbox('streaming', key="streaming")
in_expander = st.checkbox('show messages in expander', key="in_expander")
show_history = st.checkbox('show session state', key="show_history")
chat_box.context_from_session(exclude=["chat_name"]) # save widget values to chat context

st.divider()

Expand All @@ -71,6 +80,22 @@ with st.sidebar:
chat_box.init_session()
chat_box.output_messages()

def on_feedback(
feedback,
chat_history_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) # convert emoji to integer
# do something
st.session_state["need_rerun"] = True


feedback_kwargs = {
"feedback_type": "thumbs",
"optional_text_label": "wellcome to feedback",
}

if query := st.chat_input('input your question here'):
chat_box.user_say(query)
if streaming:
Expand All @@ -91,6 +116,11 @@ if query := st.chat_input('input your question here'):
# update the element without focus
chat_box.update_msg(text, element_index=0, streaming=False, state="complete")
chat_box.update_msg("\n\n".join(docs), element_index=1, streaming=False, state="complete")
chat_history_id = "some id"
chat_box.show_feedback(**feedback_kwargs,
key=chat_history_id,
on_submit=on_feedback,
kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1})
else:
text, docs = llm.chat(query)
chat_box.ai_say(
Expand Down Expand Up @@ -156,7 +186,8 @@ if btns.button("clear history"):


if show_history:
st.write(chat_box.history)
st.write(st.session_state)

```

## Todos
Expand Down Expand Up @@ -189,3 +220,14 @@ if show_history:
- [x] import json

- [x] support output of langchain' Agent.
- [x] conext bound to chat

# changelog

## v1.1.12
- fix type hint error with streamlit >= 1.33.0 (#8)
- add ChatBox.change_chat_name to rename a chat conversation
- maintain a context bound to chat conversation, it is like to be a sub session_state for every chat, context will get changed when you switch chat names.
- user can save chat bound values by `ChatBox.context['x'] = 1`
- values of widgets specified with a key can be saved to chat context with `ChatBox.context_from_session` and restored to st.session_state by `ChatBox.context_to_session`

18 changes: 13 additions & 5 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@

llm = FakeLLM()
chat_box = ChatBox()
chat_box.use_chat_name("chat1") # add a chat conversatoin

def on_chat_change():
chat_box.use_chat_name(st.session_state["chat_name"])
chat_box.context_to_session() # restore widget values to st.session_state when chat name changed


with st.sidebar:
st.subheader('start to chat using streamlit')
streaming = st.checkbox('streaming', True)
in_expander = st.checkbox('show messages in expander', True)
show_history = st.checkbox('show history', False)
chat_name = st.selectbox("Chat Session:", ["default", "chat1"], key="chat_name", on_change=on_chat_change)
chat_box.use_chat_name(chat_name)
streaming = st.checkbox('streaming', key="streaming")
in_expander = st.checkbox('show messages in expander', key="in_expander")
show_history = st.checkbox('show session state', key="show_history")
chat_box.context_from_session(exclude=["chat_name"]) # save widget values to chat context

st.divider()

Expand Down Expand Up @@ -44,7 +52,7 @@ def on_feedback(

feedback_kwargs = {
"feedback_type": "thumbs",
"optional_text_label": "欢迎反馈您打分的理由",
"optional_text_label": "wellcome to feedback",
}

if query := st.chat_input('input your question here'):
Expand Down Expand Up @@ -137,4 +145,4 @@ def on_feedback(


if show_history:
st.write(chat_box.history)
st.write(st.session_state)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def readme():

setuptools.setup(
name='streamlit-chatbox',
version='1.1.11',
version='1.1.12',
author='liunux',
author_email='[email protected]',
description='A chat box and some helpful tools used to build chatbot app with streamlit',
Expand Down
11 changes: 6 additions & 5 deletions streamlit_chatbox/elements.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import *
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
# from pydantic import BaseModel, Field


Expand Down Expand Up @@ -40,7 +41,7 @@ def _set_default_kwargs(self) -> None:
for k, v in default.items():
self._kwargs.setdefault(k, v)

def __call__(self, render_to: st._DeltaGenerator=None) -> st._DeltaGenerator:
def __call__(self, render_to: DeltaGenerator=None) -> DeltaGenerator:
# assert self._dg is None, "Every element can be rendered once only."
render_to = render_to or st
self._place_holder = render_to.empty()
Expand All @@ -51,11 +52,11 @@ def __call__(self, render_to: st._DeltaGenerator=None) -> st._DeltaGenerator:
return self._dg

@property
def dg(self) -> st._DeltaGenerator:
def dg(self) -> DeltaGenerator:
return self._dg

@property
def place_holder(self) -> st._DeltaGenerator:
def place_holder(self) -> DeltaGenerator:
return self._place_holder

@property
Expand Down Expand Up @@ -132,7 +133,7 @@ def from_dict(cls, d: Dict) -> "OutputElement":

return factory_cls(**kwargs)

def __call__(self, render_to: st._DeltaGenerator=None, direct: bool=False) -> st._DeltaGenerator:
def __call__(self, render_to: DeltaGenerator=None, direct: bool=False) -> DeltaGenerator:
if render_to is None:
if self._place_holder is None:
self._place_holder = st.empty()
Expand All @@ -158,7 +159,7 @@ def update_element(
title: str = None,
expanded: bool = None,
state: bool = None,
) -> st._DeltaGenerator:
) -> DeltaGenerator:
assert self.place_holder is not None, f"You must render the element {self} before setting new element."
attrs = {}
if title is not None:
Expand Down
85 changes: 72 additions & 13 deletions streamlit_chatbox/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@
}


class AttrDict(dict):
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(key)

def __setattr__(self, key: str, value: Any) -> None:
self[key] = value

def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(key)


class ChatBox:
def __init__(
self,
Expand Down Expand Up @@ -54,25 +71,33 @@ def init_session(self, clear: bool =False):
time.sleep(0.1)
self.reset_history(self._chat_name)

def reset_history(self, name=None):
def reset_history(self, name: str = None):
if not self.chat_inited:
st.session_state[self._session_key] = {}

name = name or self._chat_name
st.session_state[self._session_key][name] = []
name = name or self.cur_chat_name
st.session_state[self._session_key][name] = {"history": [], "context": AttrDict()}
if self._greetings:
st.session_state[self._session_key][name] = [{
st.session_state[self._session_key][name]["history"] = [{
"role": "assistant",
"elements": self._greetings,
"metadata": {},
}]

def use_chat_name(self, name: str ="default") -> None:
def use_chat_name(self, name: str = "default") -> None:
self.init_session()
self._chat_name = name
if name not in st.session_state[self._session_key]:
self.reset_history(name)

def change_chat_name(self, new_name: str, origin_name: str = None) -> bool:
self.init_session()
origin_name = origin_name or self.cur_chat_name
if (origin_name in st.session_state[self._session_key]
and new_name not in st.session_state[self._session_key]):
st.session_state[self._session_key][new_name] = st.session_state[self._session_key].pop(origin_name)
self._chat_name = new_name

def del_chat_name(self, name: str):
self.init_session()
if name in st.session_state[self._session_key]:
Expand All @@ -88,15 +113,48 @@ def get_chat_names(self):
def cur_chat_name(self):
return self._chat_name

@property
def context(self) -> AttrDict:
self.init_session()
return st.session_state[self._session_key].get(self._chat_name, {}).get("context", AttrDict())

@property
def history(self) -> List:
self.init_session()
return st.session_state[self._session_key].get(self._chat_name, [])
return st.session_state[self._session_key].get(self._chat_name, {}).get("history", [])

def other_history(self, chat_name: str, default: List=[]) -> Optional[List]:
self.init_session()
chat_name = chat_name or self.cur_chat_name
return st.session_state[self._session_key].get(chat_name, default)
return st.session_state[self._session_key].get(chat_name, {}).get("history", default)

def other_context(self, chat_name: str, default: AttrDict=AttrDict()) -> AttrDict:
self.init_session()
chat_name = chat_name or self.cur_chat_name
return st.session_state[self._session_key].get(chat_name, {}).get("context", default)

def context_to_session(self, chat_name: str=None, include: List[str]=[], exclude: List[str]=[]) -> None:
'''
copy context to st.session_state.
copy named variables only if `kw` specified
this can be usefull to restore session_state when you switch between chat conversations
'''
for k, v in self.other_context(chat_name).items():
if (not include or k in include
and k not in exclude):
st.session_state[k] = v

def context_from_session(self, chat_name: str=None, include: List[str]=[], exclude: List[str]=[]) -> None:
'''
copy context from st.session_state.
copy named variables only if `kw` specified
this can be usefull to save session_state when you switch between chat conversations
'''
for k, v in st.session_state.items():
if ((not include or k in include)
and k not in exclude
and k != self._session_key):
self.other_context(chat_name)[k] = v

def filter_history(
self,
Expand Down Expand Up @@ -194,7 +252,7 @@ def to_dict(
self,
) -> Dict:
'''
export current state to dict
export current state including messages and context to dict
'''
self.init_session()

Expand All @@ -208,7 +266,7 @@ def p(val):
else:
return val

histories = {x: p(self.other_history(x)) for x in self.get_chat_names()}
histories = {x: p({"history": self.other_history(x), "context": self.other_context(x)}) for x in self.get_chat_names()}
return {
"cur_chat_name": self.cur_chat_name,
"session_key": self._session_key,
Expand Down Expand Up @@ -246,11 +304,12 @@ def from_dict(
self.reset_history(name)
for h in history:
msg = {
"role": h["role"],
"elements": [OutputElement.from_dict(y) for y in h["elements"]],
"metadata": h["metadata"],
"role": h["history"]["role"],
"elements": [OutputElement.from_dict(y) for y in h["history"]["elements"]],
"metadata": h["history"]["metadata"],
}
self.other_history(name).append(msg)
self.other_context(name).update(h["context"])

self.use_chat_name(data["cur_chat_name"])
return self
Expand Down Expand Up @@ -348,7 +407,7 @@ def update_msg(
expanded: bool = None,
state: bool = None,
metadata: Dict = {},
) -> st._DeltaGenerator:
) -> DeltaGenerator:
self.init_session()
if not self.history or not self.history[history_index]["elements"]:
return
Expand Down

0 comments on commit 62bfbad

Please sign in to comment.