diff --git a/README.md b/README.md index 6ea69a9..7c1fdca 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ This package(>=1.0.0) will focus on wrapper of official chat elements to make ch A Streamlit component to show chat messages. It's basiclly a wrapper of streamlit officeial elements including the chat elemnts. -![demo](https://github.com/liunux4odoo/streamlit-chatbox/blob/master/demo.gif?raw=true) +![](demo.gif) +![](demo_agent.gif) ## Features @@ -83,10 +84,10 @@ if query := st.chat_input('input your question here'): text = "" for x, docs in generator: text += x - chat_box.update_msg(text, 0, streaming=True) - chat_box.update_msg("\n\n".join(docs), 1, streaming=False) + chat_box.update_msg(text, element_index=0, streaming=True) # update the element without focus - chat_box.update_msg(text, 0, streaming=False) + chat_box.update_msg(text, element_index=0, streaming=False, state="complete") + chat_box.update_msg("\n\n".join(docs), element_index=1, streaming=False, state="complete") else: text, docs = llm.chat(query) chat_box.ai_say( @@ -98,7 +99,8 @@ if query := st.chat_input('input your question here'): ] ) -if st.button('show me the multimedia'): +cols = st.columns(2) +if cols[0].button('show me the multimedia'): chat_box.ai_say(Image( 'https://tse4-mm.cn.bing.net/th/id/OIP-C.cy76ifbr2oQPMEs2H82D-QHaEv?w=284&h=181&c=7&r=0&o=5&dpr=1.5&pid=1.7')) time.sleep(0.5) @@ -108,6 +110,29 @@ if st.button('show me the multimedia'): chat_box.ai_say( Audio('https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_1mb.mp4')) +if cols[1].button('run agent'): + chat_box.user_say('run agent') + agent = FakeAgent() + text = "" + + if streaming: + # streaming: + chat_box.ai_say() # generate a blank placeholder to render messages + for d in agent.run_stream(): + if d["type"] == "complete": + chat_box.update_msg(expanded=False, state="complete") + chat_box.insert_msg(d["llm_output"]) + break + + if d["status"] == 1: + chat_box.update_msg(expanded=False, state="complete") + text = "" + chat_box.insert_msg(Markdown(text, title=d["text"], in_expander=True, expanded=True)) + elif d["status"] == 2: + text += d["llm_output"] + chat_box.update_msg(text, streaming=True) + else: + chat_box.update_msg(text, streaming=False) btns.download_button( "Export Markdown", @@ -130,7 +155,6 @@ if btns.button("clear history"): if show_history: st.write(chat_box.history) - ``` ## Todos @@ -160,3 +184,5 @@ if show_history: - [x] export to markdown - [x] export to json - [x] import json + +- [x] support output of langchain' Agent. diff --git a/demo_agent.gif b/demo_agent.gif new file mode 100644 index 0000000..19844aa Binary files /dev/null and b/demo_agent.gif differ diff --git a/example.py b/example.py index caa57e5..7ffe8a4 100644 --- a/example.py +++ b/example.py @@ -47,10 +47,10 @@ text = "" for x, docs in generator: text += x - chat_box.update_msg(text, 0, streaming=True) - chat_box.update_msg("\n\n".join(docs), 1, streaming=False) + chat_box.update_msg(text, element_index=0, streaming=True) # update the element without focus - chat_box.update_msg(text, 0, streaming=False) + chat_box.update_msg(text, element_index=0, streaming=False, state="complete") + chat_box.update_msg("\n\n".join(docs), element_index=1, streaming=False, state="complete") else: text, docs = llm.chat(query) chat_box.ai_say( @@ -62,7 +62,8 @@ ] ) -if st.button('show me the multimedia'): +cols = st.columns(2) +if cols[0].button('show me the multimedia'): chat_box.ai_say(Image( 'https://tse4-mm.cn.bing.net/th/id/OIP-C.cy76ifbr2oQPMEs2H82D-QHaEv?w=284&h=181&c=7&r=0&o=5&dpr=1.5&pid=1.7')) time.sleep(0.5) @@ -72,6 +73,28 @@ chat_box.ai_say( Audio('https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_1mb.mp4')) +if cols[1].button('run agent'): + chat_box.user_say('run agent') + agent = FakeAgent() + text = "" + + # streaming: + chat_box.ai_say() # generate a blank placeholder to render messages + for d in agent.run_stream(): + if d["type"] == "complete": + chat_box.update_msg(expanded=False, state="complete") + chat_box.insert_msg(d["llm_output"]) + break + + if d["status"] == 1: + chat_box.update_msg(expanded=False, state="complete") + text = "" + chat_box.insert_msg(Markdown(text, title=d["text"], in_expander=True, expanded=True)) + elif d["status"] == 2: + text += d["llm_output"] + chat_box.update_msg(text, streaming=True) + else: + chat_box.update_msg(text, streaming=False) btns.download_button( "Export Markdown", diff --git a/setup.py b/setup.py index a0532b2..c44755a 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def readme(): setuptools.setup( name='streamlit-chatbox', - version='1.1.7', + version='1.1.8', author='liunux', author_email='liunux@qq.com', description='A chat box and some helpful tools used to build chatbot app with streamlit', @@ -21,7 +21,7 @@ def readme(): classifiers=[], python_requires='>=3.8', install_requires=[ - 'streamlit>=1.24.0', + 'streamlit>=1.26.0', 'simplejson', ] ) diff --git a/streamlit_chatbox/__init__.py b/streamlit_chatbox/__init__.py index 069ac14..1c26c67 100644 --- a/streamlit_chatbox/__init__.py +++ b/streamlit_chatbox/__init__.py @@ -14,6 +14,7 @@ "Video", "OutputElement", "FakeLLM", + "FakeAgent", ] diff --git a/streamlit_chatbox/elements.py b/streamlit_chatbox/elements.py index 3e781c2..213ac81 100644 --- a/streamlit_chatbox/elements.py +++ b/streamlit_chatbox/elements.py @@ -1,6 +1,6 @@ from typing import * import streamlit as st -from pydantic import BaseModel, Field +# from pydantic import BaseModel, Field class Element: @@ -9,12 +9,13 @@ class Element: ''' def __init__(self, + *, output_method: str = "markdown", - *args: Any, + metadata: Dict = {}, **kwargs: Any, ) -> None: self._output_method = output_method - self._args = args + self._metadata = metadata self._kwargs = kwargs self._defualt_kwargs = { "markdown": { @@ -39,15 +40,28 @@ def _set_default_kwargs(self) -> None: for k, v in default.items(): self._kwargs.setdefault(k, v) - def __call__(self) -> st._DeltaGenerator: + def __call__(self, render_to: st._DeltaGenerator=None) -> st._DeltaGenerator: # assert self._dg is None, "Every element can be rendered once only." - self._place_holder = st.empty() + render_to = render_to or st + self._place_holder = render_to.empty() output_method = getattr(self._place_holder, self._output_method) assert callable( output_method), f"The attribute st.{self._output_mehtod} is not callable." self._dg = output_method(*self._args, **self._kwargs) return self._dg + @property + def dg(self) -> st._DeltaGenerator: + return self._dg + + @property + def place_holder(self) -> st._DeltaGenerator: + return self._place_holder + + @property + def metadata(self) -> Dict: + return self._metadata + class OutputElement(Element): def __init__(self, @@ -56,6 +70,7 @@ def __init__(self, title: str = "", in_expander: bool = False, expanded: bool = False, + state: Literal["running", "complete", "error"] = "running", **kwargs: Any, ) -> None: super().__init__(output_method=output_method, **kwargs) @@ -63,18 +78,23 @@ def __init__(self, self._title = title self._in_expander = in_expander self._expanded = expanded + self._state = state + self._attrs = ["_content", "_output_method", "_kwargs", "_metadata", + "_title", "_in_expander", "_expanded", "_state",] - def __call__(self) -> st._DeltaGenerator: - self._args = (self._content,) - if self._in_expander: - with st.expander(self._title, self._expanded): - return super().__call__() - else: - return super().__call__() + def clone(self) -> "OutputElement": + obj = type(self)() + for n in self._attrs: + setattr(obj, n, getattr(self, n)) + return obj + + @property + def content(self) -> Union[str, bytes]: + return self._content def __repr__(self) -> str: method = self._output_method.capitalize() - return f"{method} Element:\n{self._content}" + return f"{method} Element:\n{self.content}" def to_dict(self) -> Dict: return { @@ -83,6 +103,8 @@ def to_dict(self) -> Dict: "title": self._title, "in_expander": self._in_expander, "expanded": self._expanded, + "state": self._state, + "metadata": self._metadata, "kwargs": self._kwargs, } @@ -101,6 +123,8 @@ def from_dict(cls, d: Dict) -> "OutputElement": title=d.get("title"), in_expander=d.get("in_expander"), expanded=d.get("expanded"), + state=d.get("state"), + metadata=d.get("metadata", {}), **d.get("kwargs", {}), ) if factory_cls is cls: @@ -108,18 +132,52 @@ def from_dict(cls, d: Dict) -> "OutputElement": return factory_cls(**kwargs) + def __call__(self, render_to: st._DeltaGenerator=None, direct: bool=False) -> st._DeltaGenerator: + if render_to is None: + if self._place_holder is None: + self._place_holder = st.empty() + else: + if direct: + self._place_holder = render_to + else: + self._place_holder = render_to.empty() + temp_dg = self._place_holder + + if self._in_expander: + temp_dg = self._place_holder.status(self._title, expanded=self._expanded, state=self._state) + output_method = getattr(temp_dg, self._output_method) + assert callable( + output_method), f"The attribute st.{self._output_mehtod} is not callable." + self._dg = output_method(self._content, **self._kwargs) + return self._dg + def update_element( self, - element: "OutputElement", - streaming: Optional[bool] = None, + element: "OutputElement" = None, + *, + title: str = None, + expanded: bool = None, + state: bool = None, ) -> st._DeltaGenerator: - assert self._place_holder is not None, "You must render the element before setting new element." - with self._place_holder: - self._dg = element() + assert self.place_holder is not None, f"You must render the element {self} before setting new element." + attrs = {} + if title is not None: + attrs["_title"] = title + if expanded is not None: + attrs["_expanded"] = expanded + if state is not None: + attrs["_state"] = state + + if element is None: + element = self + for k, v in attrs.items(): + setattr(element, k, v) + + element(self.place_holder, direct=True) return self._dg - def attrs_from(self, target): - for attr in ["_in_expander", "_expanded", "_title"]: + def status_from(self, target): + for attr in ["_in_expander", "_expanded", "_title", "_state"]: setattr(self, attr, getattr(target, attr)) @@ -128,24 +186,60 @@ class InputElement(Element): class Markdown(OutputElement): - def __init__(self, content: Union[str, bytes] = "", title: str = "", in_expander: bool = False, expanded: bool = False, **kwargs: Any) -> None: + def __init__( + self, + content: Union[str, bytes] = "", + title: str = "", + in_expander: bool = False, + expanded: bool = False, + state: Literal["running", "complete", "error"] = "running", + **kwargs: Any, + ) -> None: super().__init__(content, output_method="markdown", title=title, - in_expander=in_expander, expanded=expanded, **kwargs) + in_expander=in_expander, expanded=expanded, + state=state, **kwargs) class Image(OutputElement): - def __init__(self, content: Union[str, bytes] = "", title: str = "", in_expander: bool = False, expanded: bool = False, **kwargs: Any) -> None: + def __init__( + self, + content: Union[str, bytes] = "", + title: str = "", + in_expander: bool = False, + expanded: bool = False, + state: Literal["running", "complete", "error"] = "running", + **kwargs: Any, + ) -> None: super().__init__(content, output_method="image", title=title, - in_expander=in_expander, expanded=expanded, **kwargs) + in_expander=in_expander, expanded=expanded, + state=state, **kwargs) class Audio(OutputElement): - def __init__(self, content: Union[str, bytes] = "", title: str = "", in_expander: bool = False, expanded: bool = False, **kwargs: Any) -> None: + def __init__( + self, + content: Union[str, bytes] = "", + title: str = "", + in_expander: bool = False, + expanded: bool = False, + state: Literal["running", "complete", "error"] = "running", + **kwargs: Any, + ) -> None: super().__init__(content, output_method="audio", title=title, - in_expander=in_expander, expanded=expanded, **kwargs) + in_expander=in_expander, expanded=expanded, + state=state, **kwargs) class Video(OutputElement): - def __init__(self, content: Union[str, bytes] = "", title: str = "", in_expander: bool = False, expanded: bool = False, **kwargs: Any) -> None: + def __init__( + self, + content: Union[str, bytes] = "", + title: str = "", + in_expander: bool = False, + expanded: bool = False, + state: Literal["running", "complete", "error"] = "running", + **kwargs: Any, + ) -> None: super().__init__(content, output_method="video", title=title, - in_expander=in_expander, expanded=expanded, **kwargs) + in_expander=in_expander, expanded=expanded, + state=state, **kwargs) diff --git a/streamlit_chatbox/messages.py b/streamlit_chatbox/messages.py index a696b77..0d35e16 100644 --- a/streamlit_chatbox/messages.py +++ b/streamlit_chatbox/messages.py @@ -4,16 +4,28 @@ import simplejson as json +# main concept: +# ChatBox is the top level object, repsents all history messages。 +# every message is feed to st.chat_message,including two objects: +# - role +# - elements: list of OutputElement to render. every element includes: +# - content and output_method, ie. st.output_method(content, **kwargs) +# - in_exapander decides the element is rendered directly or in st.status +# - if directly: element is rendered in a st.empty in the st.container +# - if in expander: element is rendered in a st.empty in the st.status + + class ChatBox: def __init__( self, chat_name: str = "default", - session_key: str = "messages", + session_key: str = "chat_history", user_avatar: str = "user", assistant_avatar: str = "assistant", greetings: Union[str, OutputElement, List[Union[str, OutputElement]]] = [], ) -> None: self._chat_name = chat_name + self._chat_containers = [] self._session_key = session_key self._user_avatar = user_avatar self._assistant_avatar = assistant_avatar @@ -27,38 +39,42 @@ def __init__( @property def chat_inited(self): - return self._session_key in st.session_state + return self._session_key in st.session_state.keys() def init_session(self, clear: bool =False): if not self.chat_inited or clear: st.session_state[self._session_key] = {} + time.sleep(0.1) self.reset_history(self._chat_name) def reset_history(self, name=None): - assert self.chat_inited, "please call init_session first" + if not self.chat_inited: + st.session_state[self._session_key] = {} + name = name or self._chat_name - st.session_state[self._session_key].update({name: []}) + st.session_state[self._session_key][name] = [] if self._greetings: st.session_state[self._session_key][name] = [{ "role": "assistant", "elements": self._greetings, + "metadata": {}, }] def use_chat_name(self, name: str ="default") -> None: - assert self.chat_inited, "please call init_session first" + self.init_session() self._chat_name = name if name not in st.session_state[self._session_key]: self.reset_history(name) def del_chat_name(self, name: str): - assert self.chat_inited, "please call init_session first" + self.init_session() if name in st.session_state[self._session_key]: msgs = st.session_state[self._session_key].pop(name) self._chat_name=self.get_chat_names()[0] return msgs def get_chat_names(self): - assert self.chat_inited, "please call init_session first" + self.init_session() return list(st.session_state[self._session_key].keys()) @property @@ -67,11 +83,11 @@ def cur_chat_name(self): @property def history(self) -> List: - assert self.chat_inited, "please call init_session first" + self.init_session() return st.session_state[self._session_key].get(self._chat_name, []) def other_history(self, chat_name: str, default: List=[]) -> Optional[List]: - assert self.chat_inited, "please call init_session first" + self.init_session() chat_name = chat_name or self.cur_chat_name return st.session_state[self._session_key].get(chat_name, default) @@ -87,13 +103,13 @@ def filter_history( filter: custom filter fucntion with arguments (msg,) or (msg, index), return None if skipping msg. default filter returns all text/markdown content. stop: custom function to stop filtering with arguments (history,) history is already filtered messages, return True if stop. default stop on history_len ''' - assert self.chat_inited, "please call init_session first" + self.init_session() def default_filter(msg, index=None): ''' filter text messages only with the format {"role":role, "content":content} ''' - content = [x._content for x in msg["elements"] if x._output_method in ["markdown", "text"]] + content = [x.content for x in msg["elements"] if x._output_method in ["markdown", "text"]] return { "role": msg["role"], "content": "\n\n".join(content), @@ -101,8 +117,8 @@ def default_filter(msg, index=None): def default_stop(history): if isinstance(history_len, int): - ai_count = len(x for x in history if x["role"] == "user") - return ai_count >= history_len + user_count = len(x for x in history if x["role"] == "user") + return user_count >= history_len else: return False @@ -122,15 +138,15 @@ def default_stop(history): filtered = filter(msg, i) if filtered is not None: result.insert(0, filtered) - if isinstance(history_len, int) and len(result) >= history_len: - break + + if stop(history): + break return result def export2md( self, chat_name: str = None, - filter: Callable = None, user_avatar: str = "User", ai_avatar: str = "AI", user_bg_color: str = "#DCFDC8", @@ -141,7 +157,7 @@ def export2md( default export messages as table of text. use callback(msg) to custom exported content. ''' - assert self.chat_inited, "please call init_session first" + self.init_session() lines = [ "\n" "| | |\n", @@ -156,7 +172,7 @@ def set_bg_color(text, bg_color): if callable(callback): line = callback(msg) else: - contents = [e._content for e in msg["elements"]] + contents = [e.content for e in msg["elements"]] if msg["role"] == "user": content = "

".join(set_bg_color(c, user_bg_color) for c in contents) avatar = set_bg_color(user_avatar, user_bg_color) @@ -173,7 +189,7 @@ def to_dict( ''' export current state to dict ''' - assert self.chat_inited, "please call init_session first" + self.init_session() def p(val): if isinstance(val, (list, tuple)): @@ -242,56 +258,67 @@ def _prepare_elements( elif isinstance(elements, list): elements = [Markdown(e) if isinstance( e, str) else e for e in elements] - return elements + return elements or [] def user_say( self, - elements: Union[OutputElement, str, List[Union[OutputElement, str]]], - to_history: bool = True, - not_render: bool = False, + elements: Union[OutputElement, str, List[Union[OutputElement, str]]] = None, + metadata: Dict = {}, ) -> List[OutputElement]: - assert self.chat_inited, "please call init_session first" + self.init_session() elements = self._prepare_elements(elements) - if not not_render: - with st.chat_message("user", avatar=self._user_avatar): - for element in elements: - element() - if to_history: - self.history.append({"role": "user", "elements": elements}) + + chat_ele = st.chat_message("user", avatar=self._user_avatar) + self._chat_containers.append(chat_ele) + for element in elements: + element(render_to=chat_ele) + + self.history.append({"role": "user", "elements": elements, "metadata": metadata}) return elements def ai_say( self, - elements: Union[OutputElement, str, List[Union[OutputElement, str]]], - to_history: bool = True, - not_render: bool = False, + elements: Union[OutputElement, str, List[Union[OutputElement, str]]] = None, + metadata: Dict = {}, ) -> List[OutputElement]: - assert self.chat_inited, "please call init_session first" + self.init_session() elements = self._prepare_elements(elements) - if not not_render: - with st.chat_message("assistant", avatar=self._assistant_avatar): - for element in elements: - element() - if to_history: - self.history.append({"role": "assistant", "elements": elements}) + + chat_ele = st.chat_message("assistant", avatar=self._assistant_avatar) + container = chat_ele.container() + self._chat_containers.append(container) + for element in elements: + element(render_to=container) + + self.history.append({"role": "assistant", "elements": elements, "metadata": metadata}) return elements def output_messages(self): - assert self.chat_inited, "please call init_session first" + self.init_session() + self._chat_containers = [] for msg in self.history: avatar = self._user_avatar if msg["role"] == "user" else self._assistant_avatar - with st.chat_message(msg["role"], avatar=avatar): - for element in msg["elements"]: - element() + chat_ele = st.chat_message(msg["role"], avatar=avatar) + container = chat_ele.container() + self._chat_containers.append(container) + for element in msg["elements"]: + element(render_to=container) def update_msg( self, - element: Union["OutputElement", str], + element: Union["OutputElement", str] = None, + *, element_index: int = -1, history_index: int = -1, streaming: Optional[bool] = None, + title: str = None, + expanded: bool = None, + state: bool = None, ) -> st._DeltaGenerator: - assert self.chat_inited, "please call init_session first" + self.init_session() + if not self.history or not self.history[history_index]["elements"]: + return + if isinstance(element, str): element = Markdown(element) if streaming is None: @@ -299,18 +326,41 @@ def update_msg( if streaming and isinstance(element, Markdown): element._content += " ▌" - old_element = self.history[history_index]["elements"][element_index] + old_element: OutputElement = self.history[history_index]["elements"][element_index] + if element is not None: + element.status_from(old_element) + self.history[history_index]["elements"][element_index] = element + dg = old_element.update_element( - element, streaming + element, + title=title, + expanded=expanded, + state=state, ) - element.attrs_from(old_element) - self.history[history_index]["elements"][element_index] = element return dg + def insert_msg( + self, + element: Union["OutputElement", str], + *, + history_index: int = -1, + pos: int = -1, + ) -> OutputElement: + self.init_session() + if isinstance(element, str): + element = Markdown(element) + elements = self.history[history_index]["elements"] + if pos < 0: + pos += len(elements) + 1 + elements.insert(pos, element) + + element(render_to=self._chat_containers[history_index]) + return element + class FakeLLM: def _answer(self, query: str) -> str: - answer = f"this is my answer for your question:\n\n{query}" + answer = f"this is llm answer for your question:\n\n{query}" docs = ["reference 1", "reference 2", "reference 3"] return answer, docs @@ -322,3 +372,104 @@ def chat_stream(self, query: str): for t in text: yield t, docs time.sleep(0.1) + + +class FakeAgent: + llm = FakeLLM() + tools = ["search", "math"] + + def thought(self, msg): + return f"thought {msg}" + + def action(self, msg): + return f"action {msg}" + + def run(self, query: str = "", steps: int = 2): + result = [] + for i in range(1, steps + 1): + thought = self.thought(i) + result.append({ + "type": "thought", + "id": i, + "text": thought, + "llm_output": self.llm.chat(thought)[0] + }) + + action = self.action(i) + result.append({ + "type": "action", + "id": i, + "text": action, + "llm_output": self.llm.chat(action)[0] + }) + + result.append({ + "type": "complete", + "llm_output": "final answer" + }) + + return result + + def run_stream(self, query: str = "", steps: int = 2): + for i in range(1, steps + 1): + thought = self.thought(i) + + yield { + "type": "thought", + "id": i, + "text": thought, + "status": 1, + "llm_output": "", + } + + for chunk, _ in self.llm.chat_stream(thought): + d = { + "type": "thought", + "id": i, + "text": thought, + "status": 2, + "llm_output": chunk, + } + print(d) + yield d + + yield { + "type": "thought", + "id": i, + "text": thought, + "status": 3, + "llm_output": "", + } + + action = self.action(i) + yield { + "type": "action", + "id": i, + "text": action, + "status": 1, + "llm_output": "", + } + + for chunk, _ in self.llm.chat_stream(action): + d = { + "type": "action", + "id": i, + "text": action, + "status": 2, + "llm_output": chunk, + } + print(d) + yield d + + yield { + "type": "action", + "id": i, + "text": action, + "status": 3, + "llm_output": "", + } + + yield { + "type": "complete", + "llm_output": "final answer" + }