# Mixture-of-Agents in 50 lines of code
# (forked from togethercomputer/MoA)
import asyncio
import os

from together import AsyncTogether, Together

# Sync client streams the final aggregated answer; async client fans out
# the concurrent calls to the reference models.
client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
async_client = AsyncTogether(api_key=os.environ.get("TOGETHER_API_KEY"))

user_prompt = "What are some fun things to do in SF?"

# Proposer layer: each reference model answers the prompt independently.
reference_models = [
    "Qwen/Qwen2-72B-Instruct",
    "Qwen/Qwen1.5-72B-Chat",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",
    "databricks/dbrx-instruct",
]

# Aggregation layer: one model synthesizes the proposers' answers.
aggregator_model = "mistralai/Mixtral-8x22B-Instruct-v0.1"
aggregator_system_prompt = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:"""


async def run_llm(model):
    """Run a single LLM call with a reference model."""
    response = await async_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    print(model)
    return response.choices[0].message.content


async def main():
    # Query all reference models concurrently.
    results = await asyncio.gather(*[run_llm(model) for model in reference_models])

    # Stream the aggregator's synthesis of the collected answers.
    final_stream = client.chat.completions.create(
        model=aggregator_model,
        messages=[
            {"role": "system", "content": aggregator_system_prompt},
            {"role": "user", "content": ",".join(str(element) for element in results)},
        ],
        stream=True,
    )
    for chunk in final_stream:
        print(chunk.choices[0].delta.content or "", end="", flush=True)


asyncio.run(main())
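# Usage sketch (assumptions: the `together` SDK is installed and a Together
# API key is exported in the shell; model availability may change over time):
#
#   pip install together
#   export TOGETHER_API_KEY="your-key"
#   python moa.py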