diff --git a/src/app/api/ai/surface/route.ts b/src/app/api/ai/surface/route.ts
index ed7a0628..d6782278 100644
--- a/src/app/api/ai/surface/route.ts
+++ b/src/app/api/ai/surface/route.ts
@@ -1,12 +1,14 @@
 import { NextResponse } from "next/server";
-import { OpenAI, OpenAIChat } from "langchain/llms/openai";
 import { BufferMemory, ChatMessageHistory } from "langchain/memory";
 import { AIMessage, HumanMessage } from "langchain/schema";
 import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface";
+// using openai
+import { OpenAIChat } from "langchain/llms/openai";
+import { OpenAIEmbeddings } from "langchain/embeddings/openai";
 
 export async function POST(request: Request) {
   console.log("----- ----- -----");
-  console.log("----- star surface -----");
+  console.log("----- start surface -----");
 
   const reqJson = await request.json();
   const query = reqJson.query;
@@ -44,8 +46,31 @@ export async function POST(request: Request) {
     chatHistory: chatHistory,
   });
 
-  const model = new OpenAIChat({ temperature: 0 });
-  const surfaceChain = loadTridentSurfaceChain({ llm: model, memory });
+  let embeddings: OpenAIEmbeddings;
+  let llm: OpenAIChat;
+
+  if (process.env.CLOUDFLARE_AI_GATEWAY) {
+    embeddings = new OpenAIEmbeddings({
+      configuration: {
+        baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
+      },
+    });
+    llm = new OpenAIChat({
+      configuration: {
+        baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
+      },
+      temperature: 0,
+    });
+  } else {
+    embeddings = new OpenAIEmbeddings();
+    llm = new OpenAIChat({ temperature: 0 });
+  }
+
+  const surfaceChain = await loadTridentSurfaceChain({
+    embeddings,
+    llm,
+    memory,
+  });
 
   const surfaceResult = await surfaceChain.call({ input: query });
   console.log("Human:", query);
diff --git a/src/utils/langchain/chains/surface/index.ts b/src/utils/langchain/chains/surface/index.ts
index 74404240..d3f29ae3 100644
--- a/src/utils/langchain/chains/surface/index.ts
+++ b/src/utils/langchain/chains/surface/index.ts
@@ -1,21 +1,25 @@
 import { BaseLanguageModel } from "langchain/dist/base_language";
-import { TRIDENT_SURFACE_PROMPT } from "./prompt";
+import { loadTridentSurfacePrompt } from "./prompt";
 import { ConversationChain, LLMChain } from "langchain/chains";
 import { BaseMemory, BufferMemory } from "langchain/memory";
+import { Embeddings } from "langchain/embeddings/base";
 
-export const loadTridentSurfaceChain = ({
+export const loadTridentSurfaceChain = async ({
+  embeddings,
   llm,
   memory,
 }: {
+  embeddings: Embeddings;
   llm: BaseLanguageModel;
   memory?: BaseMemory;
-}): LLMChain => {
+}): Promise<LLMChain> => {
   if (memory === undefined) {
     memory = new BufferMemory();
   }
+  const prompt = await loadTridentSurfacePrompt(embeddings);
   const chain = new ConversationChain({
     llm: llm,
-    prompt: TRIDENT_SURFACE_PROMPT,
+    prompt: prompt,
     memory: memory,
   });
   return chain;
diff --git a/src/utils/langchain/chains/surface/prompt.ts b/src/utils/langchain/chains/surface/prompt.ts
index cde68575..bfea169e 100644
--- a/src/utils/langchain/chains/surface/prompt.ts
+++ b/src/utils/langchain/chains/surface/prompt.ts
@@ -1,7 +1,55 @@
-import { PromptTemplate } from "langchain/prompts";
+import { MemoryVectorStore } from "langchain/vectorstores/memory";
+import {
+  SemanticSimilarityExampleSelector,
+  PromptTemplate,
+  FewShotPromptTemplate,
+} from "langchain/prompts";
+import { Embeddings } from "langchain/embeddings/base";
 
-export const TRIDENT_SURFACE_PROMPT = new PromptTemplate({
-  template: `Your name is TRIDENT, You are an interactive web maps generating assistant. 
-You interact with the human, asking step-by-step about the areas and concerns of the map they want to create.
+export const tridentSurfaceExampleList: Array<{
+  input: string;
+  output: string;
+}> = [
+  {
+    input: "Show map of New York City.",
+    output:
+      "I copy. I'm generating maps that shows the city of New York based on OpenStreetMap data. Please wait a while...",
+  },
+  {
+    input: "ニューヨークの地図を表示して",
+    output:
+      "了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……",
+  },
+  {
+    input: "显示纽约地图",
+    output: "知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……",
+  },
+];
+
+export const loadTridentSurfacePrompt = async (embeddings: Embeddings) => {
+  const memoryVectorStore = new MemoryVectorStore(embeddings);
+  const exampleSelector = new SemanticSimilarityExampleSelector({
+    vectorStore: memoryVectorStore,
+    k: 3,
+    inputKeys: ["input"],
+  });
+  const examplePrompt = PromptTemplate.fromTemplate(
+    `Human:
+{input}
+
+AI:
+{output}
+`
+  );
+
+  for (const example of tridentSurfaceExampleList) {
+    await exampleSelector.addExample(example);
+  }
+
+  const dynamicPrompt = new FewShotPromptTemplate({
+    exampleSelector: exampleSelector,
+    examplePrompt: examplePrompt,
+    prefix: `Your name is TRIDENT, You are an interactive web maps generating assistant. 
+You interact with the human, asking step-by-step about the areas and concerns of the map they want to create.
 You will always reply according to the following rules:
 - You MUST ALWAYS reply IN THE LANGUAGE WHICH HUMAN IS WRITING.
@@ -18,23 +66,14 @@ You will always reply according to the following rules:
 - Without when human want to remove, delete or limit maps, Do not forget previous areas and concerns.
 - If you can answer human requests, you MUST ALWAYS notify to human that you are generating maps based on OpenStreetMap data.
 
-Examples:
-===
-Input text:
-Human: Show a map of New York City.
-AI: I copy. I'm generating maps that shows the city of New York based on OpenStreetMap data. Please wait a while...
-
-Input text:
-Human: ニューヨークの地図を表示して
-AI: 了解しました。OpenStreetMapのデータに基づいてニューヨーク市を表示する地図を作成しています。しばらくお待ちください……
-
-Input text:
-Human: 显示纽约地图
-AI: 知道了。我正在生成基于OpenStreetMap数据的纽约市地图。请稍等……
+### Examples: ###`,
+    suffix: `
+### Current conversation: ###
 
-Current conversation:
 {history}
 Human: {input}
-AI:`,
-  inputVariables: ["history", "input"],
-});
+AI: `,
+    inputVariables: ["history", "input"],
+  });
+  return dynamicPrompt;
+};
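
Note (not part of the diff above): a minimal usage sketch of the now-async loadTridentSurfaceChain, mirroring the non-gateway branch of route.ts. It assumes OPENAI_API_KEY is set, that it runs inside this project (the "@/" path alias comes from its tsconfig), and that the main() entry point and example query are purely illustrative.

import { OpenAIChat } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { loadTridentSurfaceChain } from "@/utils/langchain/chains/surface";

const main = async () => {
  // Plain OpenAI-backed embeddings and chat model, as in the branch without CLOUDFLARE_AI_GATEWAY.
  const embeddings = new OpenAIEmbeddings();
  const llm = new OpenAIChat({ temperature: 0 });

  // loadTridentSurfaceChain is now async: the few-shot prompt first embeds its
  // examples into a MemoryVectorStore so they can be selected by semantic similarity.
  // memory is omitted here, so the chain falls back to a fresh BufferMemory.
  const surfaceChain = await loadTridentSurfaceChain({ embeddings, llm });

  // Hypothetical query; the route handler passes the request's query here instead.
  const surfaceResult = await surfaceChain.call({ input: "Show a map of Tokyo." });
  console.log(surfaceResult);
};

main().catch(console.error);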