Skip to content

Commit

Permalink
introduce exampleSelector at surface layer
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed Jan 7, 2024
1 parent be906bb commit 5f56fbf
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 28 deletions.
33 changes: 29 additions & 4 deletions src/app/api/ai/surface/route.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { NextResponse } from "next/server";
import { OpenAI, OpenAIChat } from "langchain/llms/openai";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { AIMessage, HumanMessage } from "langchain/schema";
import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface";
// using openai
import { OpenAIChat } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";

export async function POST(request: Request) {
console.log("----- ----- -----");
console.log("----- star surface -----");
console.log("----- start surface -----");

const reqJson = await request.json();
const query = reqJson.query;
Expand Down Expand Up @@ -44,8 +46,31 @@ export async function POST(request: Request) {
chatHistory: chatHistory,
});

const model = new OpenAIChat({ temperature: 0 });
const surfaceChain = loadTridentSurfaceChain({ llm: model, memory });
let embeddings: OpenAIEmbeddings;
let llm: OpenAIChat;

if (process.env.CLOUDFLARE_AI_GATEWAY) {
embeddings = new OpenAIEmbeddings({
configuration: {
baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
},
});
llm = new OpenAIChat({
configuration: {
baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
},
temperature: 0,
});
} else {
embeddings = new OpenAIEmbeddings();
llm = new OpenAIChat({ temperature: 0 });
}

const surfaceChain = await loadTridentSurfaceChain({
embeddings,
llm,
memory,
});
const surfaceResult = await surfaceChain.call({ input: query });

console.log("Human:", query);
Expand Down
12 changes: 8 additions & 4 deletions src/utils/langchain/chains/surface/index.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import { BaseLanguageModel } from "langchain/dist/base_language";
import { TRIDENT_SURFACE_PROMPT } from "./prompt";
import { loadTridentSurfacePrompt } from "./prompt";
import { ConversationChain, LLMChain } from "langchain/chains";
import { BaseMemory, BufferMemory } from "langchain/memory";
import { Embeddings } from "langchain/embeddings/base";

export const loadTridentSurfaceChain = ({
export const loadTridentSurfaceChain = async ({
embeddings,
llm,
memory,
}: {
embeddings: Embeddings;
llm: BaseLanguageModel;
memory?: BaseMemory;
}): LLMChain => {
}): Promise<LLMChain> => {
if (memory === undefined) {
memory = new BufferMemory();
}
const prompt = await loadTridentSurfacePrompt(embeddings);
const chain = new ConversationChain({
llm: llm,
prompt: TRIDENT_SURFACE_PROMPT,
prompt: prompt,
memory: memory,
});
return chain;
Expand Down
79 changes: 59 additions & 20 deletions src/utils/langchain/chains/surface/prompt.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,55 @@
import { PromptTemplate } from "langchain/prompts";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import {
SemanticSimilarityExampleSelector,
PromptTemplate,
FewShotPromptTemplate,
} from "langchain/prompts";
import { Embeddings } from "langchain/embeddings/base";

export const TRIDENT_SURFACE_PROMPT = new PromptTemplate({
template: `Your name is TRIDENT, You are an interactive web maps generating assistant. You interact with the human, asking step-by-step about the areas and concerns of the map they want to create.
export const tridentSurfaceExampleList: Array<{
input: string;
output: string;
}> = [
{
input: "Show map of New York City.",
output:
"I copy. I'm generating maps that shows the city of New York based on OpenStreetMap data. Please wait a while...",
},
{
input: "ニューヨークの地図を表示して",
output:
"了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……",
},
{
input: "显示纽约地图",
output: "知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……",
},
];

export const loadTridentSurfacePrompt = async (embeddings: Embeddings) => {
const memoryVectorStore = new MemoryVectorStore(embeddings);
const exampleSelector = new SemanticSimilarityExampleSelector({
vectorStore: memoryVectorStore,
k: 3,
inputKeys: ["input"],
});
const examplePrompt = PromptTemplate.fromTemplate(
`Human:
{input}
AI:
{output}
`
);

for (const example of tridentSurfaceExampleList) {
await exampleSelector.addExample(example);
}

const dynamicPrompt = new FewShotPromptTemplate({
exampleSelector: exampleSelector,
examplePrompt: examplePrompt,
prefix: `Your name is TRIDENT, You are an interactive web maps generating assistant. You interact with the human, asking step-by-step about the areas and concerns of the map they want to create.
You will always reply according to the following rules:
- You MUST ALWAYS reply IN THE LANGUAGE WHICH HUMAN IS WRITING.
Expand All @@ -18,23 +66,14 @@ You will always reply according to the following rules:
- Without when human want to remove, delete or limit maps, Do not forget previous areas and concerns.
- If you can answer human requests, you MUST ALWAYS notify to human that you are generating maps based on OpenStreetMap data.
Examples:
===
Input text:
Human: Show a map of New York City.
AI: I copy. I'm generating maps that shows the city of New York based on OpenStreetMap data. Please wait a while...
Input text:
Human: ニューヨークの地図を表示して
AI: 了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……
Input text:
Human: 显示纽约地图
AI: 知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……
### Examples: ###`,
suffix: `
### Current conversation: ###
Current conversation:
{history}
Human: {input}
AI:`,
inputVariables: ["history", "input"],
});
AI: `,
inputVariables: ["history", "input"],
});
return dynamicPrompt;
};

1 comment on commit 5f56fbf

@vercel
Copy link

@vercel vercel bot commented on 5f56fbf Jan 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.