Reviewed patch: switch the app to the OpenAI function agent, load settings
from .env, render STAC previews on a folium map, and make the OSMnx tools
report failures as strings instead of raising.
Note: hunks were re-flowed and some added lines were revised during review, so
the post-image blob ids on the "index" lines are stale. Context whitespace was
reconstructed and should be checked against the tree before applying.

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..db8f53e
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,2 @@
+AIM_LOGS=""
+OPENAI_API_KEY=""
diff --git a/agents/l4m_agent.py b/agents/l4m_agent.py
index 302da89..04b29d9 100644
--- a/agents/l4m_agent.py
+++ b/agents/l4m_agent.py
@@ -12,8 +12,8 @@ def base_agent(
     llm: LLM object
     tools: List of tools to use by the agent
     """
-    # chat_history = MessagesPlaceholder(variable_name="chat_history")
-    # memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+    chat_history = MessagesPlaceholder(variable_name="chat_history")
+    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
     agent = initialize_agent(
         llm=llm,
         tools=tools,
@@ -21,11 +21,34 @@ def base_agent(
         max_iterations=5,
         early_stopping_method="generate",
         verbose=True,
-        # memory=memory,
-        # agent_kwargs={
-        #     "memory_prompts": [chat_history],
-        #     "input_variables": ["input", "agent_scratchpad", "chat_history"],
-        # },
+        memory=memory,
+        agent_kwargs={
+            "memory_prompts": [chat_history],
+            "input_variables": ["input", "agent_scratchpad", "chat_history"],
+        },
     )
     print("agent initialized")
     return agent
+
+
+def openai_function_agent(llm, tools, agent_type=AgentType.OPENAI_FUNCTIONS):
+    """OpenAI function agent that is fine-tuned to call functions with valid arguments.
+
+    llm: LLM object
+    tools: List of tools to use by the agent
+    agent_type: AgentType forwarded to initialize_agent
+    """
+    agent = initialize_agent(
+        tools=tools,
+        llm=llm,
+        agent=agent_type,
+        max_iterations=5,
+        early_stopping_method="generate",
+        verbose=True,
+        # TODO: Conversation memory stays disabled until it can handle
+        # dataframes or geojsons. To re-enable, also pass:
+        #   memory=ConversationBufferMemory(memory_key="memory", return_messages=True),
+        #   agent_kwargs={"extra_prompt_messages": [MessagesPlaceholder(variable_name="memory")]},
+    )
+    print("OpenAI function agent initialized")
+    return agent
diff --git a/app.py b/app.py
index cf916ae..ebabca1 100644
--- a/app.py
+++ b/app.py
@@ -1,4 +1,5 @@
 import os
+from dotenv import load_dotenv
 
 import rasterio as rio
 import folium
@@ -17,20 +18,23 @@
 
 from tools.mercantile_tool import MercantileTool
 from tools.geopy.geocode import GeopyGeocodeTool
-from tools.geopy.distance import GeopyDistanceTool
+
+# from tools.geopy.distance import GeopyDistanceTool
 from tools.osmnx.geometry import OSMnxGeometryTool
 from tools.osmnx.network import OSMnxNetworkTool
 from tools.stac.search import STACSearchTool
-from agents.l4m_agent import base_agent
+from agents.l4m_agent import base_agent, openai_function_agent
 
-# DEBUG
+# # DEBUG
 langchain.debug = True
+# langchain.verbose = True
+
+# Load environment variables
+load_dotenv()
 
 
 @st.cache_resource(ttl="1h")
-def get_agent(
-    openai_api_key, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION
-):
+def get_agent(openai_api_key):
     llm = ChatOpenAI(
         temperature=0,
         openai_api_key=openai_api_key,
@@ -44,7 +48,7 @@ def get_agent(
         func=DuckDuckGoSearchRun().run,
     )
     geocode_tool = GeopyGeocodeTool()
-    distance_tool = GeopyDistanceTool()
+    # distance_tool = GeopyDistanceTool()
     mercantile_tool = MercantileTool()
     geometry_tool = OSMnxGeometryTool()
     network_tool = OSMnxNetworkTool()
@@ -53,44 +57,48 @@ def get_agent(
     tools = [
         duckduckgo_tool,
         geocode_tool,
-        distance_tool,
+        # distance_tool,
         mercantile_tool,
         geometry_tool,
         network_tool,
         search_tool,
     ]
 
-    agent = base_agent(llm, tools, agent_type=agent_type)
+    agent = openai_function_agent(llm, tools)
+    # agent = base_agent(llm, tools)
     return agent
 
 
-def run_query(agent, query):
-    return response
-
-
 def plot_raster(items):
+    """Overlay the rendered preview of the least-cloudy item on a folium map."""
     st.subheader("Preview of the first item sorted by cloud cover")
     selected_item = min(items, key=lambda item: item.properties["eo:cloud_cover"])
     href = selected_item.assets["rendered_preview"].href
-    # arr = rio.open(href).read()
+    # Close the dataset as soon as the pixels are read instead of leaking the handle.
+    with rio.open(href) as src:
+        arr = src.read()
 
-    # m = folium.Map(location=[28.6, 77.7], zoom_start=6)
-
-    # img = folium.raster_layers.ImageOverlay(
-    #     name="Sentinel 2",
-    #     image=arr.transpose(1, 2, 0),
-    #     bounds=selected_item.bbox,
-    #     opacity=0.9,
-    #     interactive=True,
-    #     cross_origin=False,
-    #     zindex=1,
-    # )
+    # Reorder the item bbox (lon/lat) into the [[lat, lon], [lat, lon]] corners used below.
+    lon_min, lat_min, lon_max, lat_max = selected_item.bbox
+    bbox = [[lat_min, lon_min], [lat_max, lon_max]]
+    m = folium.Map(
+        location=[(lat_min + lat_max) / 2, (lon_min + lon_max) / 2], zoom_start=8
+    )
+    img = folium.raster_layers.ImageOverlay(
+        name="Sentinel 2",
+        image=arr.transpose(1, 2, 0),
+        bounds=bbox,
+        opacity=0.9,
+        interactive=True,
+        cross_origin=False,
+        zindex=1,
+    )
 
-    # img.add_to(m)
-    # folium.LayerControl().add_to(m)
+    img.add_to(m)
+    folium.LayerControl().add_to(m)
 
-    # folium_static(m)
-    st.image(href)
+    folium_static(m)
+    # st.image(href)
 
 
 def plot_vector(df):
@@ -102,7 +110,11 @@ def plot_vector(df):
 
 
 st.set_page_config(page_title="LLLLM", page_icon="🤖", layout="wide")
-st.subheader("🤖 I am Geo LLM Agent!")
+st.markdown("🤖 I am Geo LLM Agent!")
+st.caption(
+    "I have access to tools like :blue[STAC Search, OSM API, Geocode & Mercantile]. Feel free to ask me questions like - :orange[_lat,lng_ of a place, _parks/hospitals_ in a city, _walkable streets_ in a city or _satellite image_ on a particular date.]"
+)
+
 
 if "msgs" not in st.session_state:
     st.session_state.msgs = []
@@ -123,8 +135,11 @@ def plot_vector(df):
     openai_api_key = os.getenv("OPENAI_API_KEY")
     if not openai_api_key:
         openai_api_key = st.text_input("OpenAI API Key", type="password")
+        st.info(
+            "You can find your API key [here](https://platform.openai.com/account/api-keys)"
+        )
 
-    st.subheader("OpenAI Usage")
+    st.subheader("OpenAI Usage this Session")
     total_tokens = st.empty()
     prompt_tokens = st.empty()
     completion_tokens = st.empty()
@@ -153,15 +168,18 @@ def plot_vector(df):
         st.stop()
 
     aim_callback = AimCallbackHandler(
-        repo=".",
-        experiment_name="LLLLLM: Base Agent v0.1",
+        repo=os.getenv("AIM_LOGS") or ".",  # an unset or empty AIM_LOGS falls back to "."
+        experiment_name="LLLLLM: OpenAI function agent v0.3",
     )
 
     agent = get_agent(openai_api_key)
 
     with get_openai_callback() as cb:
         st_callback = StreamlitCallbackHandler(st.container())
-        response = agent.run(prompt, callbacks=[st_callback, aim_callback])
+        response = agent.run(
+            prompt,
+            callbacks=[st_callback, aim_callback],
+        )
 
         aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True)
 
@@ -180,6 +198,6 @@ def plot_vector(df):
         total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}")
 
     with st.chat_message(name="assistant", avatar="🤖"):
-        if type(response) == str:
+        if isinstance(response, str):
             content = response
             st.markdown(response)
@@ -188,20 +206,28 @@ def plot_vector(df):
 
             match tool:
                 case "stac-search":
-                    content = f"Found {len(result)} items from the catalog."
-                    st.markdown(content)
-                    if len(result) > 0:
+                    if len(result) == 0:
+                        content = "No items found."
+                    else:
+                        content = f"Found {len(result)} items from the catalog."
                         plot_raster(result)
+                    st.markdown(content)
                 case "geometry":
-                    content = f"Found {len(result)} geometries."
-                    gdf = result
+                    if isinstance(result, str) or len(result) == 0:
+                        content = "No geometries found."
+                    else:
+                        content = f"Found {len(result)} geometries."
+                        gdf = result
+                        plot_vector(gdf)
                     st.markdown(content)
-                    plot_vector(gdf)
                 case "network":
-                    content = f"Found {len(result)} network geometries."
-                    ndf = result
+                    if isinstance(result, str) or len(result) == 0:
+                        content = "No network geometries found."
+                    else:
+                        content = f"Found {len(result)} network geometries."
+                        ndf = result
+                        plot_vector(ndf)
                     st.markdown(content)
-                    plot_vector(ndf)
                 case _:
                     content = response
                     st.markdown(content)
diff --git a/tools/osmnx/geometry.py b/tools/osmnx/geometry.py
index ecc1287..5873ab2 100644
--- a/tools/osmnx/geometry.py
+++ b/tools/osmnx/geometry.py
@@ -1,4 +1,4 @@
-from typing import Type, Dict
+from typing import Type, Dict, Tuple, Union
 
 import osmnx as ox
 import geopandas as gpd
@@ -10,7 +10,10 @@ class PlaceWithTags(BaseModel):
     "Name of a place on the map and tags in OSM."
 
     place: str = Field(..., description="name of a place on the map.")
-    tags: Dict[str, str] = Field(..., description="open street maps tags.")
+    tags: Dict[str, str] = Field(
+        ...,
+        description="open street maps tags as dict",
+    )
 
 
 class OSMnxGeometryTool(BaseTool):
@@ -18,15 +21,25 @@ class OSMnxGeometryTool(BaseTool):
 
     name: str = "geometry"
     args_schema: Type[BaseModel] = PlaceWithTags
-    description: str = "Use this tool to get geometry of different features of the place like building footprints, parks, lakes, hospitals, schools etc. \
-    Pass the name of the place & tags of OSM as args."
+    description: str = """Use this tool to get geometry of different features of the place like building footprints, parks, lakes, hospitals, schools etc. \
+    Pass the name of the place & tags of OSM as args.
+
+    Example tags: {'building': 'yes'} or {'leisure': 'park'} or {'amenity': 'hospital'} or {'amenity': 'school'} etc.
+    """
     return_direct = True
 
-    def _run(self, place: str, tags: Dict[str, str]) -> gpd.GeoDataFrame:
-        gdf = ox.geometries_from_place(place, tags)
-        gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})]
-        gdf = gdf[["name", "geometry"]].reset_index(drop=True)
-        return ("geometry", gdf)
+    def _run(
+        self, place: str, tags: Dict[str, str]
+    ) -> Tuple[str, Union[gpd.GeoDataFrame, str]]:
+        try:
+            gdf = ox.geometries_from_place(place, tags)
+            gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})]
+            gdf = gdf[["name", "geometry"]].reset_index(drop=True)
+            response = ("geometry", gdf)
+        except Exception as e:
+            # Return the failure as a string (with its cause) so the caller can branch on it.
+            response = ("geometry", f"Error in parsing: {(place, tags)}: {e}")
+        return response
 
     def _arun(self, place: str):
         raise NotImplementedError
diff --git a/tools/osmnx/network.py b/tools/osmnx/network.py
index 24f8375..56ced52 100644
--- a/tools/osmnx/network.py
+++ b/tools/osmnx/network.py
@@ -25,10 +25,15 @@ class OSMnxNetworkTool(BaseTool):
     return_direct = True
 
     def _run(self, place: str, network_type: str) -> gpd.GeoDataFrame:
-        G = ox.graph_from_place(place, network_type=network_type, simplify=True)
-        network = utils_graph.graph_to_gdfs(G, nodes=False)
-        network = network[["name", "geometry"]].reset_index(drop=True)
-        return ("network", network)
+        try:
+            G = ox.graph_from_place(place, network_type=network_type, simplify=True)
+            network = utils_graph.graph_to_gdfs(G, nodes=False)
+            network = network[["name", "geometry"]].reset_index(drop=True)
+            response = ("network", network)
+        except Exception as e:
+            # Return the failure as a string (with its cause) so the caller can branch on it.
+            response = ("network", f"Error in parsing: {(place, network_type)}: {e}")
+        return response
 
     def _arun(self, place: str):
         raise NotImplementedError
diff --git a/tools/stac/search.py b/tools/stac/search.py
index ce8d285..c1257d6 100644
--- a/tools/stac/search.py
+++ b/tools/stac/search.py
@@ -12,7 +12,7 @@ class PlaceWithDatetimeAndBBox(BaseModel):
     "Name of a place and date."
 
     bbox: str = Field(..., description="bbox of the place")
-    datetime: str = Field(..., description="only date for the stac catalog search")
+    datetime: str = Field(..., description="only date for the stac catalog search")
 
 
 class STACSearchTool(BaseTool):