Merge pull request #175 from dbpunk-labs/174-unify-the-model-serving
feat: add og serving
imotai authored Oct 29, 2023
2 parents f46fa18 + ac3657b commit 9472902
Showing 19 changed files with 1,308 additions and 535 deletions.
17 changes: 13 additions & 4 deletions agent/src/og_agent/agent_api_server.py
@@ -5,6 +5,7 @@
 #
 # SPDX-License-Identifier: Elastic-2.0
 
+import sys
 import asyncio
 import uvicorn
 import json
@@ -19,11 +20,19 @@
 from fastapi.param_functions import Header, Annotated
 from dotenv import dotenv_values
 
-logger = logging.getLogger(__name__)
-
-# the agent config
+# the api server config
 config = dotenv_values(".env")
 
+LOG_LEVEL = (
+    logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO
+)
+logging.basicConfig(
+    level=LOG_LEVEL,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    handlers=[logging.StreamHandler(sys.stdout)],
+)
+logger = logging.getLogger(__name__)
+
 app = FastAPI()
 # the agent endpoint
 listen_addr = "%s:%s" % (
@@ -32,7 +41,6 @@
 )
 if config.get("rpc_host", "") == "0.0.0.0":
     listen_addr = "127.0.0.1:%s" % config.get("rpc_port", "9528")
-logger.info(f"connect the agent server at {listen_addr}")
agent_sdk = AgentProxySDK(listen_addr)
 
 
@@ -191,6 +199,7 @@ async def process_task(
 
 
 async def run_server():
+    logger.info(f"connect the agent server at {listen_addr}")
     port = int(config.get("rpc_port", "9528")) + 1
     server_config = uvicorn.Config(
         app, host=config.get("rpc_host", "127.0.0.1"), port=port
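
With this change the API server's log level is driven by an optional `log_level` key in `.env`: only the literal value `debug` enables DEBUG, everything else falls back to INFO, and records go to stdout. A minimal sketch of how that switch resolves, mirroring the diff (the `resolve_level` helper is illustrative, not part of the commit):

    # Illustrative only: how the new log_level switch resolves.
    import logging

    def resolve_level(config: dict) -> int:
        # only the literal value "debug" enables DEBUG; anything else is INFO
        return logging.DEBUG if config.get("log_level", "info") == "debug" else logging.INFO

    assert resolve_level({"log_level": "debug"}) == logging.DEBUG
    assert resolve_level({"log_level": "warning"}) == logging.INFO  # no WARNING mapping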
14 changes: 6 additions & 8 deletions agent/src/og_agent/agent_builder.py
@@ -6,25 +6,23 @@
 """ """
 import json
 from .prompt import OCTOGEN_FUNCTION_SYSTEM, OCTOGEN_CODELLAMA_SYSTEM
-from .codellama_agent import CodellamaAgent
+from .llama_agent import LlamaAgent
 from .openai_agent import OpenaiAgent
-from .codellama_client import CodellamaClient
+from .llama_client import LlamaClient
 from .mock_agent import MockAgent
 
 
-def build_codellama_agent(endpoint, key, sdk, grammer_path):
+def build_llama_agent(endpoint, key, sdk, grammer_path):
     """
-    build codellama agent
+    build llama agent
     """
     with open(grammer_path, "r") as fd:
         grammar = fd.read()
 
-    client = CodellamaClient(
-        endpoint, key, OCTOGEN_CODELLAMA_SYSTEM, "Octogen", "User", grammar
-    )
+    client = LlamaClient(endpoint, key, grammar)
 
     # init the agent
-    return CodellamaAgent(client, sdk)
+    return LlamaAgent(client, sdk)
 
 
 def build_openai_agent(sdk, model_name, is_azure=True):
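
For reference, a hedged usage sketch of the renamed builder; the endpoint, key, and grammar path below are placeholders, and `sdk` would be a connected kernel SDK in real use (illustrative, not part of the commit):

    # Illustrative only: wiring the renamed builder with placeholder values.
    from og_agent.agent_builder import build_llama_agent

    agent = build_llama_agent(
        endpoint="http://127.0.0.1:8080",  # placeholder llama serving endpoint
        key="placeholder-key",
        sdk=None,  # a KernelSDK instance in real use
        grammer_path="grammar.bnf",  # parameter name keeps the repo's spelling
    )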
6 changes: 3 additions & 3 deletions agent/src/og_agent/agent_server.py
@@ -24,7 +24,7 @@
 from og_sdk.kernel_sdk import KernelSDK
 from og_sdk.utils import parse_image_filename
 from .agent_llm import LLMManager
-from .agent_builder import build_mock_agent, build_openai_agent, build_codellama_agent
+from .agent_builder import build_mock_agent, build_openai_agent, build_llama_agent
 import databases
 import orm
 from datetime import datetime
@@ -116,11 +116,11 @@ async def add_kernel(
             agent = build_mock_agent(sdk, config["cases_path"])
             self.agents[request.key] = {"sdk": sdk, "agent": agent}
         elif config["llm_key"] == "codellama":
-            logger.info(f"create a codellama agent {request.endpoint}")
+            logger.info(f"create a llama agent {request.endpoint}")
             grammer_path = os.path.join(
                 pathlib.Path(__file__).parent.resolve(), "grammar.bnf"
             )
-            agent = build_codellama_agent(
+            agent = build_llama_agent(
                 config["llama_api_base"], config["llama_api_key"], sdk, grammer_path
             )
             self.agents[request.key] = {"sdk": sdk, "agent": agent}
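
Note that routing still keys off `llm_key == "codellama"` even though the builder is now llama-generic. The configuration this branch reads would look something like the following (keys taken from the diff, values are placeholders):

    # Illustrative only: config keys read by this branch; values are placeholders.
    config = {
        "llm_key": "codellama",  # still the routing value after the rename
        "llama_api_base": "http://127.0.0.1:8080",
        "llama_api_key": "placeholder-key",
    }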
254 changes: 254 additions & 0 deletions agent/src/og_agent/base_agent.py
@@ -5,14 +5,19 @@
 
 """ """
 import json
+import io
 import logging
+import time
 from typing import List
 from pydantic import BaseModel, Field
 from og_proto.kernel_server_pb2 import ExecuteResponse
 from og_proto.agent_server_pb2 import TaskResponse, ContextState
 from og_sdk.utils import parse_image_filename, process_char_stream
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer, TypingContent
 from .tokenizer import tokenize
+import tiktoken
 
+encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
 logger = logging.getLogger(__name__)
 
 
@@ -55,6 +60,255 @@ class BaseAgent:

     def __init__(self, sdk):
         self.kernel_sdk = sdk
+        self.model_name = ""
+
+    def _merge_delta_for_function_call(self, message, delta):
+        if len(message.keys()) == 0:
+            message.update(delta)
+            return
+        if "function_call" not in message:
+            message["function_call"] = delta["function_call"]
+            return
+        old_arguments = message["function_call"].get("arguments", "")
+        if delta["function_call"]["arguments"]:
+            message["function_call"]["arguments"] = (
+                old_arguments + delta["function_call"]["arguments"]
+            )
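
`_merge_delta_for_function_call` accumulates the `arguments` string across streamed chunks. A minimal sketch of that behavior with hypothetical OpenAI-style delta payloads (illustrative, not part of the commit; `BaseAgent(None)` works here because `__init__` only stores the SDK):

    # Illustrative only: merging two streamed function_call deltas.
    agent = BaseAgent(None)
    message = {}
    agent._merge_delta_for_function_call(
        message, {"function_call": {"name": "python", "arguments": '{"code": "pri'}}
    )
    agent._merge_delta_for_function_call(
        message, {"function_call": {"arguments": 'nt(1)"}'}}
    )
    assert message["function_call"]["arguments"] == '{"code": "print(1)"}'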

+    def _merge_delta_for_content(self, message, delta):
+        if not delta:
+            return
+        content = message.get("content", "")
+        if delta.get("content"):
+            message["content"] = content + delta["content"]

+    def _parse_arguments(
+        self,
+        arguments,
+        is_code=False,
+        first_field_name="explanation",
+        second_field_name="code",
+    ):
+        """
+        parse the partial key with string value from json
+        """
+        if is_code:
+            return TypingState.CODE, "", arguments
+        state = TypingState.START
+        explanation_str = ""
+        code_str = ""
+        logger.debug(f"the arguments {arguments}")
+        for token_state, token in tokenize(io.StringIO(arguments)):
+            if token_state == None:
+                if state == TypingState.EXPLANATION and token[0] == 1:
+                    explanation_str = token[1]
+                    state = TypingState.START
+                if state == TypingState.CODE and token[0] == 1:
+                    code_str = token[1]
+                    state = TypingState.START
+                if token[1] == first_field_name:
+                    state = TypingState.EXPLANATION
+                if token[1] == second_field_name:
+                    state = TypingState.CODE
+            else:
+                # String
+                if token_state == 9 and state == TypingState.EXPLANATION:
+                    explanation_str = "".join(token)
+                elif token_state == 9 and state == TypingState.CODE:
+                    code_str = "".join(token)
+        return (state, explanation_str, code_str)
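
`_parse_arguments` exists because function-call arguments stream in as incomplete JSON, which the stdlib parser rejects outright; the incremental tokenizer recovers the `explanation` and `code` fields from such prefixes. A small illustration of the problem it solves (not part of the commit):

    # Illustrative only: stdlib json cannot handle a truncated, mid-stream payload.
    import json

    partial = '{"explanation": "sum a list", "code": "print(su'
    try:
        json.loads(partial)
    except json.JSONDecodeError:
        pass  # an incremental tokenizer is needed to recover fields from prefixes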

+    def _get_message_token_count(self, message):
+        response_token_count = 0
+        if "function_call" in message and message["function_call"]:
+            arguments = message["function_call"].get("arguments", "")
+            response_token_count += len(encoding.encode(arguments))
+        if "content" in message and message["content"]:
+            response_token_count += len(encoding.encode(message.get("content")))
+        return response_token_count
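
Token accounting here uses `tiktoken` with the gpt-3.5-turbo encoding fixed at module load, regardless of which model actually served the response. A quick sketch of the underlying call (illustrative, not part of the commit):

    # Illustrative only: counting BPE tokens the way _get_message_token_count does.
    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    n = len(encoding.encode('{"code": "print(1)"}'))  # token count of the arguments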

+    async def _read_function_call_message(
+        self, message, queue, old_text_content, old_code_content, task_context, task_opt
+    ):
+        typing_language = "text"
+        if message["function_call"].get("name", "") in [
+            "execute_python_code",
+            "python",
+        ]:
+            typing_language = "python"
+        elif message["function_call"].get("name", "") == "execute_bash_code":
+            typing_language = "bash"
+        is_code = False
+        if message["function_call"].get("name", "") == "python":
+            is_code = True
+        arguments = message["function_call"].get("arguments", "")
+        return await self._send_typing_message(
+            arguments,
+            queue,
+            old_text_content,
+            old_code_content,
+            typing_language,
+            task_context,
+            task_opt,
+            is_code=is_code,
+        )
+
+    async def _read_json_message(
+        self, message, queue, old_text_content, old_code_content, task_context, task_opt
+    ):
+        arguments = message.get("content", "")
+        typing_language = "text"
+        if arguments.find("execute_python_code") >= 0:
+            typing_language = "python"
+        elif arguments.find("execute_bash_code") >= 0:
+            typing_language = "bash"
+
+        return await self._send_typing_message(
+            arguments,
+            queue,
+            old_text_content,
+            old_code_content,
+            typing_language,
+            task_context,
+            task_opt,
+        )
+
+    async def _send_typing_message(
+        self,
+        arguments,
+        queue,
+        old_text_content,
+        old_code_content,
+        language,
+        task_context,
+        task_opt,
+        is_code=False,
+    ):
+        """
+        send the typing message to the client
+        """
+        (state, explanation_str, code_str) = self._parse_arguments(arguments, is_code)
+        logger.debug(
+            f"argument explanation:{explanation_str} code:{code_str} text_content:{old_text_content}"
+        )
+        if explanation_str and old_text_content != explanation_str:
+            typed_chars = explanation_str[len(old_text_content) :]
+            new_text_content = explanation_str
+            if task_opt.streaming and len(typed_chars) > 0:
+                task_response = TaskResponse(
+                    state=task_context.to_context_state_proto(),
+                    response_type=TaskResponse.OnModelTypeText,
+                    typing_content=TypingContent(content=typed_chars, language="text"),
+                )
+                await queue.put(task_response)
+            return new_text_content, old_code_content
+        if code_str and old_code_content != code_str:
+            typed_chars = code_str[len(old_code_content) :]
+            code_content = code_str
+            if task_opt.streaming and len(typed_chars) > 0:
+                await queue.put(
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnModelTypeCode,
+                        typing_content=TypingContent(
+                            content=typed_chars, language=language
+                        ),
+                    )
+                )
+            return old_text_content, code_content
+        return old_text_content, old_code_content
+
+    async def extract_message(
+        self,
+        response_generator,
+        queue,
+        rpc_context,
+        task_context,
+        task_opt,
+        start_time,
+        is_json_format=False,
+    ):
+        """
+        extract the chunk from the response generator
+        """
+        message = {}
+        text_content = ""
+        code_content = ""
+        context_output_token_count = task_context.output_token_count
+        start_time = time.time()
+        async for chunk in response_generator:
+            if rpc_context.done():
+                logger.debug("the client has cancelled the request")
+                break
+            if not chunk["choices"]:
+                continue
+            logger.debug(f"the chunk {chunk}")
+            task_context.llm_name = chunk.get("model", "")
+            self.model_name = chunk.get("model", "")
+            delta = chunk["choices"][0]["delta"]
+            if "function_call" in delta:
+                self._merge_delta_for_function_call(message, delta)
+                response_token_count = self._get_message_token_count(message)
+                task_context.output_token_count = (
+                    response_token_count + context_output_token_count
+                )
+                task_context.llm_response_duration += int(
+                    (time.time() - start_time) * 1000
+                )
+                start_time = time.time()
+                (
+                    new_text_content,
+                    new_code_content,
+                ) = await self._read_function_call_message(
+                    message,
+                    queue,
+                    text_content,
+                    code_content,
+                    task_context,
+                    task_opt,
+                )
+                text_content = new_text_content
+                code_content = new_code_content
+            else:
+                self._merge_delta_for_content(message, delta)
+                task_context.llm_response_duration += int(
+                    (time.time() - start_time) * 1000
+                )
+                start_time = time.time()
+                if message.get("content") != None:
+                    response_token_count = self._get_message_token_count(message)
+                    task_context.output_token_count = (
+                        response_token_count + context_output_token_count
+                    )
+                if is_json_format:
+                    (
+                        new_text_content,
+                        new_code_content,
+                    ) = await self._read_json_message(
+                        message,
+                        queue,
+                        text_content,
+                        code_content,
+                        task_context,
+                        task_opt,
+                    )
+                    text_content = new_text_content
+                    code_content = new_code_content
+
+                elif task_opt.streaming and delta.get("content"):
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelTypeText,
+                            typing_content=TypingContent(
+                                content=delta["content"], language="text"
+                            ),
+                        )
+                    )
+        logger.info(
+            f"call the {self.model_name} with input token {task_context.input_token_count} and output token count {task_context.output_token_count}"
+        )
+        return message
 
     async def call_function(self, code, context, task_context=None):
         """
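
`extract_message` consumes an async generator of OpenAI-style streaming chunks and fans typed output into an asyncio queue. The chunk shape it expects, reconstructed from the field accesses above (illustrative, not part of the commit):

    # Illustrative only: one streamed chunk as extract_message reads it.
    chunk = {
        "model": "gpt-3.5-turbo",  # copied into task_context.llm_name
        "choices": [
            {"delta": {"function_call": {"name": "python", "arguments": '{"code": "'}}}
        ],
    }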
3 changes: 1 addition & 2 deletions agent/src/og_agent/base_stream_client.py
@@ -19,12 +19,11 @@ def __init__(self, endpoint, key):
 
     async def arun(self, request_data):
         logging.debug(f"{request_data}")
-        data = json.dumps(request_data)
         headers = {"Authorization": self.key}
         async with aiohttp.ClientSession(
             headers=headers, raise_for_status=True
         ) as session:
-            async with session.post(self.endpoint, data=data) as r:
+            async with session.post(self.endpoint, json=request_data) as r:
                 async for line in r.content:
                     if line:
                         yield line
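
Switching from `data=json.dumps(request_data)` to `json=request_data` lets aiohttp serialize the payload itself and, unlike the old call, set the `Content-Type: application/json` header automatically. A minimal sketch of the new call pattern (the endpoint is a placeholder, not from the commit):

    # Illustrative only: aiohttp's json= serializes and sets Content-Type.
    import aiohttp

    async def post_json(session: aiohttp.ClientSession, payload: dict) -> str:
        async with session.post("http://127.0.0.1:8080/v1/chat", json=payload) as r:
            return await r.text()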
(diff for the remaining 14 changed files not shown)
