diff --git a/.gitignore b/.gitignore index 479d1527..7ed18ff0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ evals/**/public lib/dom/bundle.js evals/public *.tgz +evals/playground.ts \ No newline at end of file diff --git a/evals/index.eval.ts b/evals/index.eval.ts index 14ed5dc1..2c5073be 100644 --- a/evals/index.eval.ts +++ b/evals/index.eval.ts @@ -269,7 +269,10 @@ const peeler_complex = async () => { await stagehand.page.goto(`https://chefstoys.com/`, { timeout: 60000 }); await stagehand.act({ - action: "search for peelers", + action: "search for %search_query%", + variables: { + search_query: "peeler", + }, }); await stagehand.act({ diff --git a/evals/playground.ts b/evals/playground.ts index 3f15f7ab..e69de29b 100644 --- a/evals/playground.ts +++ b/evals/playground.ts @@ -1,126 +0,0 @@ -import { Stagehand } from "../lib"; -import { z } from "zod"; -import { EvalLogger } from "./utils"; - -// eval failing -const homedepot = async () => { - const stagehand = new Stagehand({ - env: "LOCAL", - verbose: 1, - debugDom: true, - headless: process.env.HEADLESS !== "false", - }); - - await stagehand.init(); - - try { - await stagehand.page.goto("https://www.homedepot.com/"); - - await stagehand.act({ action: "search for gas grills" }); - - await stagehand.act({ action: "click on the first gas grill" }); - - await stagehand.act({ action: "click on the Product Details" }); - - await stagehand.act({ action: "find the Primary Burner BTU" }); - - const productSpecs = await stagehand.extract({ - instruction: "Extract the Primary Burner BTU of the product", - schema: z.object({ - productSpecs: z - .array( - z.object({ - burnerBTU: z.string().describe("Primary Burner BTU"), - }), - ) - .describe("Gas grill Primary Burner BTU"), - }), - modelName: "gpt-4o-2024-08-06", - }); - console.log("The gas grill primary burner BTU is:", productSpecs); - - if ( - !productSpecs || - !productSpecs.productSpecs || - productSpecs.productSpecs.length === 0 - ) { - return false; - } - - return true; - } catch (error) { - console.error(`Error in homedepot function: ${error.message}`); - return false; - } finally { - await stagehand.context.close(); - } -}; - -const vanta = async () => { - const logger = new EvalLogger(); - - const stagehand = new Stagehand({ - env: "LOCAL", - headless: process.env.HEADLESS !== "false", - logger: (message: any) => { - logger.log(message); - }, - verbose: 2, - }); - - logger.init(stagehand); - - const { debugUrl, sessionUrl } = await stagehand.init(); - - await stagehand.page.goto("https://www.vanta.com/"); - - const observations = await stagehand.observe({ - instruction: "find the text for the request demo button", - }); - - console.log("Observations:", observations); - - if (observations.length === 0) { - await stagehand.context.close(); - return { - _success: false, - observations, - debugUrl, - sessionUrl, - logs: logger.getLogs(), - }; - } - - const observationResult = await stagehand.page - .locator(observations[0].selector) - .first() - .innerHTML(); - - const expectedLocator = `body > div.page-wrapper > div.nav_component > div.nav_element.w-nav > div.padding-global > div > div > nav > div.nav_cta-wrapper.is-new > a.nav_cta-button-desktop.is-smaller.w-button`; - - const expectedResult = await stagehand.page - .locator(expectedLocator) - .first() - .innerHTML(); - - await stagehand.context.close(); - - return { - _success: observationResult == expectedResult, - expected: expectedResult, - actual: observationResult, - debugUrl, - sessionUrl, - logs: logger.getLogs(), - }; -}; - -async function main() { - // const homedepotResult = await homedepot(); - const vantaResult = await vanta(); - - // console.log("Result:", homedepotResult); - console.log("Result:", vantaResult); -} - -main().catch(console.error); diff --git a/examples/example.ts b/examples/example.ts index 4303391e..b1696feb 100644 --- a/examples/example.ts +++ b/examples/example.ts @@ -21,6 +21,7 @@ async function example() { }); console.log(`Our favorite contributor is ${contributor.username}`); } + (async () => { await example(); })(); diff --git a/lib/cache/ActionCache.ts b/lib/cache/ActionCache.ts new file mode 100644 index 00000000..f7398cf6 --- /dev/null +++ b/lib/cache/ActionCache.ts @@ -0,0 +1,137 @@ +import { BaseCache, CacheEntry } from "./BaseCache"; + +export interface PlaywrightCommand { + method: string; + args: string[]; +} + +export interface ActionEntry extends CacheEntry { + data: { + playwrightCommand: PlaywrightCommand; + componentString: string; + xpaths: string[]; + newStepString: string; + completed: boolean; + previousSelectors: string[]; + action: string; + }; +} + +/** + * ActionCache handles logging and retrieving actions along with their Playwright commands. + */ +export class ActionCache extends BaseCache { + constructor( + logger: (message: { + category?: string; + message: string; + level?: number; + }) => void, + cacheDir?: string, + cacheFile?: string, + ) { + super(logger, cacheDir, cacheFile || "action_cache.json"); + } + + public async addActionStep({ + url, + action, + previousSelectors, + playwrightCommand, + componentString, + xpaths, + newStepString, + completed, + requestId, + }: { + url: string; + action: string; + previousSelectors: string[]; + playwrightCommand: PlaywrightCommand; + componentString: string; + requestId: string; + xpaths: string[]; + newStepString: string; + completed: boolean; + }): Promise { + this.logger({ + category: "action_cache", + message: `Adding action step to cache: ${action}, requestId: ${requestId}, url: ${url}, previousSelectors: ${previousSelectors}`, + level: 1, + }); + + await this.set( + { url, action, previousSelectors }, + { + playwrightCommand, + componentString, + xpaths, + newStepString, + completed, + previousSelectors, + action, + }, + requestId, + ); + } + + /** + * Retrieves all actions for a specific trajectory. + * @param trajectoryId - Unique identifier for the trajectory. + * @param requestId - The identifier for the current request. + * @returns An array of TrajectoryEntry objects or null if not found. + */ + public async getActionStep({ + url, + action, + previousSelectors, + requestId, + }: { + url: string; + action: string; + previousSelectors: string[]; + requestId: string; + }): Promise { + const data = await super.get({ url, action, previousSelectors }, requestId); + if (!data) { + return null; + } + + return data; + } + + public async removeActionStep(cacheHashObj: { + url: string; + action: string; + previousSelectors: string[]; + requestId: string; + }): Promise { + await super.delete(cacheHashObj); + } + + /** + * Clears all actions for a specific trajectory. + * @param trajectoryId - Unique identifier for the trajectory. + * @param requestId - The identifier for the current request. + */ + public async clearAction(requestId: string): Promise { + await super.deleteCacheForRequestId(requestId); + this.logger({ + category: "action_cache", + message: `Cleared action for ID: ${requestId}`, + level: 1, + }); + } + + /** + * Resets the entire action cache. + */ + public async resetCache(): Promise { + await super.resetCache(); + this.logger({ + category: "action_cache", + message: "Action cache has been reset.", + level: 1, + }); + } +} diff --git a/lib/cache/BaseCache.ts b/lib/cache/BaseCache.ts new file mode 100644 index 00000000..1c4f1119 --- /dev/null +++ b/lib/cache/BaseCache.ts @@ -0,0 +1,438 @@ +import * as fs from "fs"; +import * as path from "path"; +import * as crypto from "crypto"; + +export interface CacheEntry { + timestamp: number; + data: any; + requestId: string; +} + +export interface CacheStore { + [key: string]: CacheEntry; +} + +export class BaseCache { + private readonly CACHE_MAX_AGE_MS = 7 * 24 * 60 * 60 * 1000; // 1 week in milliseconds + private readonly CLEANUP_PROBABILITY = 0.01; // 1% chance + + protected cacheDir: string; + protected cacheFile: string; + protected lockFile: string; + protected logger: (message: { + category?: string; + message: string; + level?: number; + }) => void; + + private readonly LOCK_TIMEOUT_MS = 1_000; + protected lockAcquired = false; + protected lockAcquireFailures = 0; + + // Added for request ID tracking + protected requestIdToUsedHashes: { [key: string]: string[] } = {}; + + constructor( + logger: (message: { + category?: string; + message: string; + level?: number; + }) => void, + cacheDir: string = path.join(process.cwd(), "tmp", ".cache"), + cacheFile: string = "cache.json", + ) { + this.logger = logger; + this.cacheDir = cacheDir; + this.cacheFile = path.join(cacheDir, cacheFile); + this.lockFile = path.join(cacheDir, "cache.lock"); + this.ensureCacheDirectory(); + this.setupProcessHandlers(); + } + + private setupProcessHandlers(): void { + const releaseLockAndExit = () => { + this.releaseLock(); + process.exit(); + }; + + process.on("exit", releaseLockAndExit); + process.on("SIGINT", releaseLockAndExit); + process.on("SIGTERM", releaseLockAndExit); + process.on("uncaughtException", (err) => { + this.logger({ + category: "base_cache", + message: `Uncaught exception: ${err}`, + level: 2, + }); + if (this.lockAcquired) { + releaseLockAndExit(); + } + }); + } + + protected ensureCacheDirectory(): void { + if (!fs.existsSync(this.cacheDir)) { + fs.mkdirSync(this.cacheDir, { recursive: true }); + this.logger({ + category: "base_cache", + message: `Created cache directory at ${this.cacheDir}`, + level: 1, + }); + } + } + + protected createHash(data: any): string { + const hash = crypto.createHash("sha256"); + return hash.update(JSON.stringify(data)).digest("hex"); + } + + protected sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + public async acquireLock(): Promise { + const startTime = Date.now(); + while (Date.now() - startTime < this.LOCK_TIMEOUT_MS) { + try { + if (fs.existsSync(this.lockFile)) { + const lockAge = Date.now() - fs.statSync(this.lockFile).mtimeMs; + if (lockAge > this.LOCK_TIMEOUT_MS) { + fs.unlinkSync(this.lockFile); + this.logger({ + category: "base_cache", + message: "Stale lock file removed", + level: 1, + }); + } + } + + fs.writeFileSync(this.lockFile, process.pid.toString(), { flag: "wx" }); + this.lockAcquireFailures = 0; + this.lockAcquired = true; + this.logger({ + category: "base_cache", + message: "Lock acquired", + level: 1, + }); + return true; + } catch (error) { + await this.sleep(5); + } + } + this.logger({ + category: "base_cache", + message: "Failed to acquire lock after timeout", + level: 2, + }); + this.lockAcquireFailures++; + if (this.lockAcquireFailures >= 3) { + this.logger({ + category: "base_cache", + message: + "Failed to acquire lock 3 times in a row. Releasing lock manually.", + level: 1, + }); + this.releaseLock(); + } + return false; + } + + public releaseLock(): void { + try { + if (fs.existsSync(this.lockFile)) { + fs.unlinkSync(this.lockFile); + this.logger({ + category: "base_cache", + message: "Lock released", + level: 1, + }); + } + this.lockAcquired = false; + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error releasing lock: ${error}`, + level: 2, + }); + } + } + + /** + * Cleans up stale cache entries that exceed the maximum age. + */ + public async cleanupStaleEntries(): Promise { + if (!(await this.acquireLock())) { + this.logger({ + category: "llm_cache", + message: "Failed to acquire lock for cleanup", + level: 2, + }); + return; + } + + try { + const cache = this.readCache(); + const now = Date.now(); + let entriesRemoved = 0; + + for (const [hash, entry] of Object.entries(cache)) { + if (now - entry.timestamp > this.CACHE_MAX_AGE_MS) { + delete cache[hash]; + entriesRemoved++; + } + } + + if (entriesRemoved > 0) { + this.writeCache(cache); + this.logger({ + category: "llm_cache", + message: `Cleaned up ${entriesRemoved} stale cache entries`, + level: 1, + }); + } + } catch (error) { + this.logger({ + category: "llm_cache", + message: `Error during cache cleanup: ${error}`, + level: 2, + }); + } finally { + this.releaseLock(); + } + } + + protected readCache(): CacheStore { + if (fs.existsSync(this.cacheFile)) { + try { + const data = fs.readFileSync(this.cacheFile, "utf-8"); + return JSON.parse(data) as CacheStore; + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error reading cache file: ${error}. Resetting cache.`, + level: 1, + }); + this.resetCache(); + return {}; + } + } + return {}; + } + + protected writeCache(cache: CacheStore): void { + try { + fs.writeFileSync(this.cacheFile, JSON.stringify(cache, null, 2)); + this.logger({ + category: "base_cache", + message: "Cache written to file", + level: 1, + }); + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error writing cache file: ${error}`, + level: 2, + }); + } finally { + this.releaseLock(); + } + } + + /** + * Retrieves data from the cache based on the provided options. + * @param hashObj - The options used to generate the cache key. + * @param requestId - The identifier for the current request. + * @returns The cached data if available, otherwise null. + */ + public async get( + hashObj: Record | string, + requestId: string, + ): Promise { + if (!(await this.acquireLock())) { + this.logger({ + category: "base_cache", + message: "Failed to acquire lock for getting cache", + level: 2, + }); + return null; + } + + try { + const hash = this.createHash(hashObj); + const cache = this.readCache(); + + if (cache[hash]) { + this.trackRequestIdUsage(requestId, hash); + return cache[hash].data; + } + return null; + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error getting cache: ${error}. Resetting cache.`, + level: 1, + }); + + this.resetCache(); + return null; + } finally { + this.releaseLock(); + } + } + + /** + * Stores data in the cache based on the provided options and requestId. + * @param hashObj - The options used to generate the cache key. + * @param data - The data to be cached. + * @param requestId - The identifier for the cache entry. + */ + public async set( + hashObj: Record, + data: T["data"], + requestId: string, + ): Promise { + if (!(await this.acquireLock())) { + this.logger({ + category: "base_cache", + message: "Failed to acquire lock for setting cache", + level: 2, + }); + return; + } + + try { + const hash = this.createHash(hashObj); + const cache = this.readCache(); + cache[hash] = { + data, + timestamp: Date.now(), + requestId, + }; + + this.writeCache(cache); + this.trackRequestIdUsage(requestId, hash); + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error setting cache: ${error}. Resetting cache.`, + level: 1, + }); + + this.resetCache(); + } finally { + this.releaseLock(); + + if (Math.random() < this.CLEANUP_PROBABILITY) { + this.cleanupStaleEntries(); + } + } + } + + public async delete(hashObj: Record): Promise { + if (!(await this.acquireLock())) { + this.logger({ + category: "base_cache", + message: "Failed to acquire lock for removing cache entry", + level: 2, + }); + return; + } + + try { + const hash = this.createHash(hashObj); + const cache = this.readCache(); + + if (cache[hash]) { + delete cache[hash]; + this.writeCache(cache); + } else { + this.logger({ + category: "base_cache", + message: "Cache entry not found to delete", + level: 1, + }); + } + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error removing cache entry: ${error}`, + level: 2, + }); + } finally { + this.releaseLock(); + } + } + + /** + * Tracks the usage of a hash with a specific requestId. + * @param requestId - The identifier for the current request. + * @param hash - The cache key hash. + */ + protected trackRequestIdUsage(requestId: string, hash: string): void { + this.requestIdToUsedHashes[requestId] ??= []; + this.requestIdToUsedHashes[requestId].push(hash); + } + + /** + * Deletes all cache entries associated with a specific requestId. + * @param requestId - The identifier for the request whose cache entries should be deleted. + */ + public async deleteCacheForRequestId(requestId: string): Promise { + if (!(await this.acquireLock())) { + this.logger({ + category: "base_cache", + message: "Failed to acquire lock for deleting cache", + level: 2, + }); + return; + } + try { + const cache = this.readCache(); + const hashes = this.requestIdToUsedHashes[requestId] ?? []; + let entriesRemoved = 0; + for (const hash of hashes) { + if (cache[hash]) { + delete cache[hash]; + entriesRemoved++; + } + } + if (entriesRemoved > 0) { + this.writeCache(cache); + } else { + this.logger({ + category: "base_cache", + message: `No cache entries found for requestId ${requestId}`, + level: 1, + }); + } + // Remove the requestId from the mapping after deletion + delete this.requestIdToUsedHashes[requestId]; + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error deleting cache for requestId ${requestId}: ${error}`, + level: 2, + }); + } finally { + this.releaseLock(); + } + } + + /** + * Resets the entire cache by clearing the cache file. + */ + public resetCache(): void { + try { + fs.writeFileSync(this.cacheFile, "{}"); + this.requestIdToUsedHashes = {}; // Reset requestId tracking + } catch (error) { + this.logger({ + category: "base_cache", + message: `Error resetting cache: ${error}`, + level: 2, + }); + } finally { + this.releaseLock(); + } + } +} diff --git a/lib/cache/LLMCache.ts b/lib/cache/LLMCache.ts new file mode 100644 index 00000000..578031a0 --- /dev/null +++ b/lib/cache/LLMCache.ts @@ -0,0 +1,48 @@ +import { BaseCache, CacheEntry } from "./BaseCache"; + +export class LLMCache extends BaseCache { + constructor( + logger: (message: { + category?: string; + message: string; + level?: number; + }) => void, + cacheDir?: string, + cacheFile?: string, + ) { + super(logger, cacheDir, cacheFile || "llm_calls.json"); + } + + /** + * Overrides the get method to track used hashes by requestId. + * @param options - The options used to generate the cache key. + * @param requestId - The identifier for the current request. + * @returns The cached data if available, otherwise null. + */ + public async get( + options: Record, + requestId: string, + ): Promise { + const data = await super.get(options, requestId); + return data; + } + + /** + * Overrides the set method to include cache cleanup logic. + * @param options - The options used to generate the cache key. + * @param data - The data to be cached. + * @param requestId - The identifier for the current request. + */ + public async set( + options: Record, + data: any, + requestId: string, + ): Promise { + await super.set(options, data, requestId); + this.logger({ + category: "llm_cache", + message: "Cache miss - saved new response", + level: 1, + }); + } +} diff --git a/lib/dom/global.d.ts b/lib/dom/global.d.ts index e1d9a6db..32af6d11 100644 --- a/lib/dom/global.d.ts +++ b/lib/dom/global.d.ts @@ -4,17 +4,17 @@ declare global { chunkNumber: number; processDom: (chunksSeen: Array) => Promise<{ outputString: string; - selectorMap: Record; + selectorMap: Record; chunk: number; chunks: number[]; }>; processAllOfDom: () => Promise<{ outputString: string; - selectorMap: Record; + selectorMap: Record; }>; processElements: (chunk: number) => Promise<{ outputString: string; - selectorMap: Record; + selectorMap: Record; }>; debugDom: () => Promise; cleanupDebug: () => void; diff --git a/lib/dom/process.ts b/lib/dom/process.ts index 41a7355b..d6407b87 100644 --- a/lib/dom/process.ts +++ b/lib/dom/process.ts @@ -1,6 +1,20 @@ +import { generateXPathsForElement as generateXPaths } from "./xpathUtils"; + +export function isElementNode(node: Node): node is Element { + return node.nodeType === Node.ELEMENT_NODE; +} + +export function isTextNode(node: Node): node is Text { + return node.nodeType === Node.TEXT_NODE && Boolean(node.textContent?.trim()); +} + export async function processDom(chunksSeen: Array) { const { chunk, chunksArray } = await pickChunk(chunksSeen); - const { outputString, selectorMap } = await processElements(chunk); + const { outputString, selectorMap } = await processElements( + chunk, + undefined, + undefined, + ); console.log( `Stagehand (Browser Process): Extracted dom elements:\n${outputString}`, @@ -58,18 +72,24 @@ export async function scrollToHeight(height: number) { scrollEndTimer = window.setTimeout(() => { window.removeEventListener("scroll", handleScrollEnd); resolve(); - }, 200); + }, 100); }; window.addEventListener("scroll", handleScrollEnd, { passive: true }); handleScrollEnd(); }); } + +const xpathCache: Map = new Map(); + export async function processElements( chunk: number, scrollToChunk: boolean = true, indexOffset: number = 0, -) { +): Promise<{ + outputString: string; + selectorMap: Record; +}> { console.time("processElements:total"); const viewportHeight = window.innerHeight; const chunkHeight = viewportHeight * chunk; @@ -89,7 +109,6 @@ export async function processElements( const candidateElements: Array = []; const DOMQueue: Array = [...document.body.childNodes]; - const xpathCache: Map = new Map(); console.log("Stagehand (Browser Process): Generating candidate elements"); console.time("processElements:findCandidates"); @@ -133,7 +152,7 @@ export async function processElements( console.timeEnd("processElements:findCandidates"); - const selectorMap: Record = {}; + const selectorMap: Record = {}; let outputString = ""; console.log( @@ -141,17 +160,28 @@ export async function processElements( ); console.time("processElements:processCandidates"); + console.time("processElements:generateXPaths"); + const xpathLists = await Promise.all( + candidateElements.map(async (element) => { + if (xpathCache.has(element)) { + return xpathCache.get(element); + } + + const xpaths = await generateXPaths(element); + xpathCache.set(element, xpaths); + return xpaths; + }), + ); + console.timeEnd("processElements:generateXPaths"); + candidateElements.forEach((element, index) => { - let xpath = xpathCache.get(element); - if (!xpath) { - xpath = generateXPath(element); - xpathCache.set(element, xpath); - } + const xpaths = xpathLists[index]; + let elementOutput = ""; if (isTextNode(element)) { const textContent = element.textContent?.trim(); if (textContent) { - outputString += `${index + indexOffset}:${textContent}\n`; + elementOutput += `${index + indexOffset}:${textContent}\n`; } } else if (isElementNode(element)) { const tagName = element.tagName.toLowerCase(); @@ -161,10 +191,11 @@ export async function processElements( const closingTag = ``; const textContent = element.textContent?.trim() || ""; - outputString += `${index + indexOffset}:${openingTag}${textContent}${closingTag}\n`; + elementOutput += `${index + indexOffset}:${openingTag}${textContent}${closingTag}\n`; } - selectorMap[index + indexOffset] = xpath; + outputString += elementOutput; + selectorMap[index + indexOffset] = xpaths; }); console.timeEnd("processElements:processCandidates"); @@ -216,48 +247,6 @@ window.processAllOfDom = processAllOfDom; window.processElements = processElements; window.scrollToHeight = scrollToHeight; -function generateXPath(element: ChildNode): string { - if (isElementNode(element) && element.id) { - return `//*[@id='${element.id}']`; - } - - const parts: string[] = []; - while (element && (isTextNode(element) || isElementNode(element))) { - let index = 0; - let hasSameTypeSiblings = false; - const siblings = element.parentElement - ? Array.from(element.parentElement.childNodes) - : []; - - for (let i = 0; i < siblings.length; i++) { - const sibling = siblings[i]; - - if ( - sibling.nodeType === element.nodeType && - sibling.nodeName === element.nodeName - ) { - index = index + 1; - hasSameTypeSiblings = true; - - if (sibling.isSameNode(element)) { - break; - } - } - } - - // text "nodes" are selected differently than elements with xPaths - if (element.nodeName !== "#text") { - const tagName = element.nodeName.toLowerCase(); - const pathIndex = hasSameTypeSiblings ? `[${index}]` : ""; - parts.unshift(`${tagName}${pathIndex}`); - } - - element = element.parentElement as HTMLElement; - } - - return parts.length ? `/${parts.join("/")}` : ""; -} - const leafElementDenyList = ["SVG", "IFRAME", "SCRIPT", "STYLE", "LINK"]; const interactiveElementTypes = [ @@ -301,16 +290,6 @@ const interactiveRoles = [ ]; const interactiveAriaRoles = ["menu", "menuitem", "button"]; -function isElementNode(node: Node): node is Element { - return node.nodeType === Node.ELEMENT_NODE; -} - -function isTextNode(node: Node): node is Text { - // Trim all white space and ensure the text node is non-empty - const trimmedText = node.textContent?.trim().replace(/\s/g, ""); - return node.nodeType === Node.TEXT_NODE && trimmedText !== ""; -} - /* * Checks if an element is visible and therefore relevant for LLMs to consider. We check: * - Size diff --git a/lib/dom/xpathUtils.ts b/lib/dom/xpathUtils.ts new file mode 100644 index 00000000..35a8cf02 --- /dev/null +++ b/lib/dom/xpathUtils.ts @@ -0,0 +1,252 @@ +import { isTextNode } from "./process"; +import { isElementNode } from "./process"; + +function getParentElement(node: ChildNode): Element | null { + return isElementNode(node) + ? node.parentElement + : (node.parentNode as Element); +} + +/** + * Generates all possible combinations of a given array of attributes. + * @param attributes Array of attributes. + * @param size The size of each combination. + * @returns An array of attribute combinations. + */ +function getCombinations( + attributes: { attr: string; value: string }[], + size: number, +): { attr: string; value: string }[][] { + const results: { attr: string; value: string }[][] = []; + + function helper(start: number, combo: { attr: string; value: string }[]) { + if (combo.length === size) { + results.push([...combo]); + return; + } + for (let i = start; i < attributes.length; i++) { + combo.push(attributes[i]); + helper(i + 1, combo); + combo.pop(); + } + } + + helper(0, []); + return results; +} + +/** + * Checks if the generated XPath uniquely identifies the target element. + * @param xpath The XPath string to test. + * @param target The target DOM element. + * @returns True if unique, else false. + */ +function isXPathFirstResultElement(xpath: string, target: Element): boolean { + try { + const result = document.evaluate( + xpath, + document.documentElement, + null, + XPathResult.ORDERED_NODE_SNAPSHOT_TYPE, + null, + ); + return result.snapshotItem(0) === target; + } catch (error) { + // If there's an error evaluating the XPath, consider it not unique + console.warn(`Invalid XPath expression: ${xpath}`, error); + return false; + } +} + +/** + * Escapes a string for use in an XPath expression. + * Handles special characters, including single and double quotes. + * @param value - The string to escape. + * @returns The escaped string safe for XPath. + */ +export function escapeXPathString(value: string): string { + if (value.includes("'")) { + if (value.includes('"')) { + // If the value contains both single and double quotes, split into parts + return ( + "concat(" + + value + .split(/('+)/) + .map((part) => { + if (part === "'") { + return `"'"`; + } else if (part.startsWith("'") && part.endsWith("'")) { + return `"${part}"`; + } else { + return `'${part}'`; + } + }) + .join(",") + + ")" + ); + } else { + // Contains single quotes but not double quotes; use double quotes + return `"${value}"`; + } + } else { + // Does not contain single quotes; use single quotes + return `'${value}'`; + } +} + +/** + * Generates both a complicated XPath and a standard XPath for a given DOM element. + * @param element - The target DOM element. + * @param documentOverride - Optional document override. + * @returns An object containing both XPaths. + */ +export async function generateXPathsForElement( + element: ChildNode, +): Promise { + // Generate the standard XPath + if (!element) return []; + const [complexXPath, standardXPath, idBasedXPath] = await Promise.all([ + generateComplexXPath(element), + generateStandardXPath(element), + generatedIdBasedXPath(element), + ]); + + // This should return in order from most accurate on current page to most cachable. + // Do not change the order if you are not sure what you are doing. + // Contact Navid if you need help understanding it. + return [...(idBasedXPath ? [idBasedXPath] : []), standardXPath, complexXPath]; +} + +async function generateComplexXPath(element: ChildNode): Promise { + // Generate the complicated XPath + const parts: string[] = []; + let currentElement: ChildNode | null = element; + + while ( + currentElement && + (isTextNode(currentElement) || isElementNode(currentElement)) + ) { + if (isElementNode(currentElement)) { + const el = currentElement as Element; + let selector = el.tagName.toLowerCase(); + + // List of attributes to consider for uniqueness + const attributePriority = [ + "data-qa", + "data-component", + "data-role", + "role", + "aria-role", + "type", + "name", + "aria-label", + "placeholder", + "title", + "alt", + ]; + + // Collect attributes present on the element + const attributes = attributePriority + .map((attr) => { + let value = el.getAttribute(attr); + if (attr === "href-full" && value) { + value = el.getAttribute("href"); + } + return value + ? { attr: attr === "href-full" ? "href" : attr, value } + : null; + }) + .filter((attr) => attr !== null) as { attr: string; value: string }[]; + + // Attempt to find a combination of attributes that uniquely identifies the element + let uniqueSelector = ""; + for (let i = 1; i <= attributes.length; i++) { + const combinations = getCombinations(attributes, i); + for (const combo of combinations) { + const conditions = combo + .map((a) => `@${a.attr}=${escapeXPathString(a.value)}`) + .join(" and "); + const xpath = `//${selector}[${conditions}]`; + if (isXPathFirstResultElement(xpath, el)) { + uniqueSelector = xpath; + break; + } + } + if (uniqueSelector) break; + } + + if (uniqueSelector) { + parts.unshift(uniqueSelector.replace("//", "")); + break; + } else { + // Fallback to positional selector + const parent = getParentElement(el); + if (parent) { + const siblings = Array.from(parent.children).filter( + (sibling) => sibling.tagName === el.tagName, + ); + const index = siblings.indexOf(el as HTMLElement) + 1; + selector += siblings.length > 1 ? `[${index}]` : ""; + } + parts.unshift(selector); + } + } + + currentElement = getParentElement(currentElement); + } + + const xpath = "//" + parts.join("/"); + return xpath; +} + +/** + * Generates a standard XPath for a given DOM element. + * @param element - The target DOM element. + * @returns A standard XPath string. + */ +async function generateStandardXPath(element: ChildNode): Promise { + const parts: string[] = []; + while (element && (isTextNode(element) || isElementNode(element))) { + let index = 0; + let hasSameTypeSiblings = false; + const siblings = element.parentElement + ? Array.from(element.parentElement.childNodes) + : []; + + for (let i = 0; i < siblings.length; i++) { + const sibling = siblings[i]; + + if ( + sibling.nodeType === element.nodeType && + sibling.nodeName === element.nodeName + ) { + index = index + 1; + hasSameTypeSiblings = true; + + if (sibling.isSameNode(element)) { + break; + } + } + } + + // text "nodes" are selected differently than elements with xPaths + if (element.nodeName !== "#text") { + const tagName = element.nodeName.toLowerCase(); + const pathIndex = hasSameTypeSiblings ? `[${index}]` : ""; + parts.unshift(`${tagName}${pathIndex}`); + } + + element = element.parentElement as HTMLElement; + } + + return parts.length ? `//${parts.join("//")}` : ""; +} + +async function generatedIdBasedXPath( + element: ChildNode, +): Promise { + if (isElementNode(element) && element.id) { + return `//*[@id='${element.id}']`; + } + return null; +} diff --git a/lib/handlers/actHandler.ts b/lib/handlers/actHandler.ts new file mode 100644 index 00000000..074c4695 --- /dev/null +++ b/lib/handlers/actHandler.ts @@ -0,0 +1,1071 @@ +import { Stagehand } from "../index"; +import { AvailableModel, LLMProvider } from "../llm/LLMProvider"; +import { ScreenshotService } from "../vision"; +import { verifyActCompletion, act, fillInVariables } from "../inference"; +import { + PlaywrightCommandException, + PlaywrightCommandMethodNotSupportedException, +} from "../types"; +import { Locator, Page } from "@playwright/test"; +import { ActionCache } from "../cache/ActionCache"; +import { modelsWithVision } from "../llm/LLMClient"; +import { generateId } from "../utils"; + +export class StagehandActHandler { + private readonly stagehand: Stagehand; + private readonly verbose: 0 | 1 | 2; + private readonly llmProvider: LLMProvider; + private readonly enableCaching: boolean; + private readonly logger: (log: { + category: string; + message: string; + level: 0 | 1 | 2; + }) => void; + private readonly waitForSettledDom: ( + domSettleTimeoutMs?: number, + ) => Promise; + private readonly actionCache: ActionCache; + private readonly defaultModelName: AvailableModel; + private readonly startDomDebug: () => Promise; + private readonly cleanupDomDebug: () => Promise; + private actions: { [key: string]: { result: string; action: string } }; + + constructor({ + stagehand, + verbose, + llmProvider, + enableCaching, + logger, + waitForSettledDom, + defaultModelName, + startDomDebug, + cleanupDomDebug, + }: { + stagehand: Stagehand; + verbose: 0 | 1 | 2; + llmProvider: LLMProvider; + enableCaching: boolean; + logger: (log: { + category: string; + message: string; + level: 0 | 1 | 2; + }) => void; + waitForSettledDom: (domSettleTimeoutMs?: number) => Promise; + defaultModelName: AvailableModel; + startDomDebug: () => Promise; + cleanupDomDebug: () => Promise; + }) { + this.stagehand = stagehand; + this.verbose = verbose; + this.llmProvider = llmProvider; + this.enableCaching = enableCaching; + this.logger = logger; + this.waitForSettledDom = waitForSettledDom; + this.actionCache = new ActionCache(this.logger); + this.defaultModelName = defaultModelName; + this.startDomDebug = startDomDebug; + this.cleanupDomDebug = cleanupDomDebug; + this.actions = {}; + } + + private async _recordAction(action: string, result: string): Promise { + const id = generateId(action); + + this.actions[id] = { result, action }; + + return id; + } + + private async _verifyActionCompletion({ + completed, + verifierUseVision, + requestId, + action, + steps, + model, + domSettleTimeoutMs, + }: { + completed: boolean; + verifierUseVision: boolean; + requestId: string; + action: string; + steps: string; + model: AvailableModel; + domSettleTimeoutMs?: number; + }): Promise { + await this.waitForSettledDom(domSettleTimeoutMs); + + const { selectorMap } = await this.stagehand.page.evaluate(() => { + return window.processAllOfDom(); + }); + + let actionCompleted = false; + if (completed) { + // Run action completion verifier + this.stagehand.log({ + category: "action", + message: `Action marked as completed, Verifying if this is true...`, + level: 1, + }); + + let domElements: string | undefined = undefined; + let fullpageScreenshot: Buffer | undefined = undefined; + + if (verifierUseVision) { + try { + const screenshotService = new ScreenshotService( + this.stagehand.page, + selectorMap, + this.verbose, + ); + + fullpageScreenshot = await screenshotService.getScreenshot(true, 15); + } catch (e) { + this.stagehand.log({ + category: "action", + message: `Error getting full page screenshot: ${e.message}\n. Trying again...`, + level: 1, + }); + + const screenshotService = new ScreenshotService( + this.stagehand.page, + selectorMap, + this.verbose, + ); + + fullpageScreenshot = await screenshotService.getScreenshot(true, 15); + } + } else { + ({ outputString: domElements } = await this.stagehand.page.evaluate( + () => { + return window.processAllOfDom(); + }, + )); + } + + actionCompleted = await verifyActCompletion({ + goal: action, + steps, + llmProvider: this.llmProvider, + modelName: model, + screenshot: fullpageScreenshot, + domElements, + logger: this.logger, + requestId, + }); + + this.stagehand.log({ + category: "action", + message: `Action completion verification result: ${actionCompleted}`, + level: 1, + }); + } + + return actionCompleted; + } + + private async _performPlaywrightMethod( + method: string, + args: string[], + xpath: string, + domSettleTimeoutMs?: number, + ) { + const locator = this.stagehand.page.locator(`xpath=${xpath}`).first(); + const initialUrl = this.stagehand.page.url(); + if (method === "scrollIntoView") { + this.stagehand.log({ + category: "action", + message: `Scrolling element into view`, + level: 2, + }); + try { + await locator + .evaluate((element: any) => { + element.scrollIntoView({ behavior: "smooth", block: "center" }); + }) + .catch((e: Error) => { + this.stagehand.log({ + category: "action", + message: `Error scrolling element into view: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + }); + } catch (e) { + this.stagehand.log({ + category: "action", + message: `Error scrolling element into view: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (method === "fill" || method === "type") { + try { + await locator.fill(""); + await locator.click(); + const text = args[0]; + for (const char of text) { + await this.stagehand.page.keyboard.type(char, { + delay: Math.random() * 50 + 25, + }); + } + } catch (e) { + this.logger({ + category: "action", + message: `Error filling element: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (method === "press") { + try { + const key = args[0]; + await this.stagehand.page.keyboard.press(key); + } catch (e) { + this.logger({ + category: "action", + message: `Error pressing key: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (typeof locator[method as keyof typeof locator] === "function") { + // Log current URL before action + this.logger({ + category: "action", + message: `Page URL before action: ${this.stagehand.page.url()}`, + level: 2, + }); + + // Perform the action + try { + // @ts-ignore + await locator[method](...args); + } catch (e) { + this.logger({ + category: "action", + message: `Error performing method ${method} with args ${JSON.stringify( + args, + )}: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + + throw new PlaywrightCommandException(e.message); + } + + // Handle navigation if a new page is opened + if (method === "click") { + this.logger({ + category: "action", + message: `Clicking element, checking for page navigation`, + level: 1, + }); + + // NAVIDNOTE: Should this happen before we wait for locator[method]? + const newOpenedTab = await Promise.race([ + new Promise((resolve) => { + this.stagehand.context.once("page", (page) => resolve(page)); + setTimeout(() => resolve(null), 1_500); + }), + ]); + + this.logger({ + category: "action", + message: `Clicked element, ${ + newOpenedTab ? "opened a new tab" : "no new tabs opened" + }`, + level: 1, + }); + + if (newOpenedTab) { + this.logger({ + category: "action", + message: `New page detected (new tab) with URL: ${newOpenedTab.url()}`, + level: 1, + }); + await newOpenedTab.close(); + await this.stagehand.page.goto(newOpenedTab.url()); + await this.stagehand.page.waitForLoadState("domcontentloaded"); + await this.waitForSettledDom(domSettleTimeoutMs); + } + + // Wait for the network to be idle with timeout of 5s (will only wait if loading a new page) + // await this.waitForSettledDom(domSettleTimeoutMs); + await Promise.race([ + this.stagehand.page.waitForLoadState("networkidle"), + new Promise((resolve) => setTimeout(resolve, 5_000)), + ]).catch((e: Error) => { + this.logger({ + category: "action", + message: `Network idle timeout hit`, + level: 1, + }); + }); + + this.logger({ + category: "action", + message: `Finished waiting for (possible) page navigation`, + level: 1, + }); + + if (this.stagehand.page.url() !== initialUrl) { + this.logger({ + category: "action", + message: `New page detected with URL: ${this.stagehand.page.url()}`, + level: 1, + }); + } + } + } else { + this.logger({ + category: "action", + message: `Chosen method ${method} is invalid`, + level: 1, + }); + + throw new PlaywrightCommandMethodNotSupportedException( + `Method ${method} not supported`, + ); + } + + await this.waitForSettledDom(domSettleTimeoutMs); + } + + private async _getComponentString(locator: Locator) { + return await locator.evaluate((el) => { + // Create a clone of the element to avoid modifying the original + const clone = el.cloneNode(true) as HTMLElement; + + // Keep only specific stable attributes that help identify elements + const attributesToKeep = [ + "type", + "name", + "placeholder", + "aria-label", + "role", + "href", + "title", + "alt", + ]; + + // Remove all attributes except those we want to keep + Array.from(clone.attributes).forEach((attr) => { + if (!attributesToKeep.includes(attr.name)) { + clone.removeAttribute(attr.name); + } + }); + + const outerHtml = clone.outerHTML; + + // const variables = { + // // Replace with your actual variables and their values + // // Example: + // username: "JohnDoe", + // email: "john@example.com", + // }; + + // // Function to replace variable values with variable names + // const replaceVariables = (element: Element) => { + // if (element instanceof HTMLElement) { + // for (const [key, value] of Object.entries(variables)) { + // if (value) { + // element.innerText = element.innerText.replace( + // new RegExp(value, "g"), + // key, + // ); + // } + // } + // } + + // if ( + // element instanceof HTMLInputElement || + // element instanceof HTMLTextAreaElement + // ) { + // for (const [key, value] of Object.entries(variables)) { + // if (value) { + // element.value = element.value.replace( + // new RegExp(value, "g"), + // key, + // ); + // } + // } + // } + // }; + + // // Replace variables in the cloned element + // replaceVariables(clone); + + // // Replace variables in all child elements + // clone.querySelectorAll("*").forEach(replaceVariables); + return outerHtml.trim().replace(/\s+/g, " "); + }); + } + + private async getElement( + xpath: string, + timeout: number = 5_000, + ): Promise { + try { + const element = this.stagehand.page.locator(`xpath=${xpath}`).first(); + await element.waitFor({ state: "attached", timeout }); + return element; + } catch { + this.logger({ + category: "action", + message: `Element with XPath ${xpath} not found within ${timeout}ms.`, + level: 1, + }); + return null; + } + } + + private async _checkIfCachedStepIsValid_oneXpath(cachedStep: { + xpath: string; + savedComponentString: string; + }) { + this.logger({ + category: "action", + message: `Checking if cached step is valid: ${cachedStep.xpath}, ${cachedStep.savedComponentString}`, + level: 1, + }); + try { + const locator = await this.getElement(cachedStep.xpath); + if (!locator) { + this.logger({ + category: "action", + message: `Locator not found for xpath: ${cachedStep.xpath}`, + level: 1, + }); + return false; + } + + this.logger({ + category: "action", + message: `locator element: ${await this._getComponentString(locator)}`, + level: 1, + }); + + // First try to get the value (for input/textarea elements) + let currentComponent = await this._getComponentString(locator); + + this.logger({ + category: "action", + message: `Current text: ${currentComponent}`, + level: 1, + }); + + if (!currentComponent || !cachedStep.savedComponentString) { + this.logger({ + category: "action", + message: `Current text or cached text is undefined`, + level: 1, + }); + return false; + } + + // Normalize whitespace and trim both strings before comparing + const normalizedCurrentText = currentComponent + .trim() + .replace(/\s+/g, " "); + const normalizedCachedText = cachedStep.savedComponentString + .trim() + .replace(/\s+/g, " "); + + if (normalizedCurrentText !== normalizedCachedText) { + this.logger({ + category: "action", + message: `Current text and cached text do not match: ${normalizedCurrentText} !== ${normalizedCachedText}`, + level: 1, + }); + return false; + } + + return true; + } catch (e) { + this.logger({ + category: "action", + message: `Error checking if cached step is valid: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + return false; // Added explicit return false for error cases + } + } + + private async _getValidCachedStepXpath(cachedStep: { + xpaths: string[]; + savedComponentString: string; + }) { + const reversedXpaths = [...cachedStep.xpaths].reverse(); // We reverse the xpaths to try the most cachable ones first + for (const xpath of reversedXpaths) { + const isValid = await this._checkIfCachedStepIsValid_oneXpath({ + xpath, + savedComponentString: cachedStep.savedComponentString, + }); + + if (isValid) { + return xpath; + } + } + return null; + } + + private async _runCachedActionIfAvailable({ + action, + previousSelectors, + requestId, + steps, + chunksSeen, + modelName, + useVision, + verifierUseVision, + retries, + variables, + model, + domSettleTimeoutMs, + }: { + action: string; + previousSelectors: string[]; + requestId: string; + steps: string; + chunksSeen: number[]; + modelName: AvailableModel; + useVision: boolean | "fallback"; + verifierUseVision: boolean; + retries: number; + variables: Record; + model: AvailableModel; + domSettleTimeoutMs?: number; + }) { + const cacheObj = { + url: this.stagehand.page.url(), + action, + previousSelectors, + requestId, + }; + + this.logger({ + category: "action", + message: `Checking action cache for: ${JSON.stringify(cacheObj)}`, + level: 1, + }); + + const cachedStep = await this.actionCache.getActionStep(cacheObj); + + if (!cachedStep) { + this.logger({ + category: "action", + message: `Action cache miss: ${JSON.stringify(cacheObj)}`, + level: 1, + }); + return null; + } + + this.logger({ + category: "action", + message: `Action cache semi-hit: ${cachedStep.playwrightCommand.method} with args: ${JSON.stringify( + cachedStep.playwrightCommand.args, + )}`, + level: 1, + }); + + try { + const validXpath = await this._getValidCachedStepXpath({ + xpaths: cachedStep.xpaths, + savedComponentString: cachedStep.componentString, + }); + + this.logger({ + category: "action", + message: `Cached action step is valid: ${validXpath !== null}`, + level: 1, + }); + + if (!validXpath) { + this.logger({ + category: "action", + message: `Cached action step is invalid, removing...`, + level: 1, + }); + + await this.actionCache.removeActionStep(cacheObj); + return null; + } + + this.logger({ + category: "action", + message: `Action Cache Hit: ${cachedStep.playwrightCommand.method} with args: ${JSON.stringify( + cachedStep.playwrightCommand.args, + )}`, + level: 1, + }); + + cachedStep.playwrightCommand.args = cachedStep.playwrightCommand.args.map( + (arg) => { + return fillInVariables(arg, variables); + }, + ); + + await this._performPlaywrightMethod( + cachedStep.playwrightCommand.method, + cachedStep.playwrightCommand.args, + validXpath, + domSettleTimeoutMs, + ); + + steps = steps + cachedStep.newStepString; + const { outputString, selectorMap } = await this.stagehand.page.evaluate( + ({ chunksSeen }: { chunksSeen: number[] }) => { + // @ts-ignore + return window.processDom(chunksSeen); + }, + { chunksSeen }, + ); + + if (cachedStep.completed) { + // Verify the action was completed successfully + let actionCompleted = await this._verifyActionCompletion({ + completed: true, + verifierUseVision, + model, + steps, + requestId, + action, + domSettleTimeoutMs, + }); + + this.logger({ + category: "action", + message: `Action completion verification result from cache: ${actionCompleted}`, + level: 1, + }); + + if (actionCompleted) { + return { + success: true, + message: "Action completed successfully using cached step", + action, + }; + } + } + + return this.act({ + action, + steps, + chunksSeen, + modelName, + useVision, + verifierUseVision, + retries, + requestId, + variables, + previousSelectors: [...previousSelectors, cachedStep.xpaths[0]], + skipActionCacheForThisStep: false, + domSettleTimeoutMs, + }); + } catch (exception) { + this.logger({ + category: "action", + message: `Error performing cached action step: ${exception.message}\nTrace: ${exception.stack}`, + level: 1, + }); + + await this.actionCache.removeActionStep(cacheObj); + return null; + } + } + + public async act({ + action, + steps = "", + chunksSeen, + modelName, + useVision, + verifierUseVision, + retries = 0, + requestId, + variables, + previousSelectors, + skipActionCacheForThisStep = false, + domSettleTimeoutMs, + }: { + action: string; + steps?: string; + chunksSeen: number[]; + modelName?: AvailableModel; + useVision: boolean | "fallback"; + verifierUseVision: boolean; + retries?: number; + requestId?: string; + variables: Record; + previousSelectors: string[]; + skipActionCacheForThisStep: boolean; + domSettleTimeoutMs?: number; + }): Promise<{ success: boolean; message: string; action: string }> { + try { + await this.waitForSettledDom(domSettleTimeoutMs); + + await this.startDomDebug(); + + const model = modelName ?? this.defaultModelName; + + if (this.enableCaching && !skipActionCacheForThisStep) { + const response = await this._runCachedActionIfAvailable({ + action, + previousSelectors, + requestId, + steps, + chunksSeen, + modelName: model, + useVision, + verifierUseVision, + retries, + variables, + model, + domSettleTimeoutMs, + }); + + if (response !== null) { + return response; + } else { + return this.act({ + action, + steps, + chunksSeen, + modelName, + useVision, + verifierUseVision, + retries, + requestId, + variables, + previousSelectors, + skipActionCacheForThisStep: true, + domSettleTimeoutMs, + }); + } + } + + if ( + !modelsWithVision.includes(model) && + (useVision !== false || verifierUseVision) + ) { + this.logger({ + category: "action", + message: `${model} does not support vision, but useVision was set to ${useVision}. Defaulting to false.`, + level: 1, + }); + useVision = false; + verifierUseVision = false; + } + + this.logger({ + category: "action", + message: `Running / Continuing action: ${action} on page: ${this.stagehand.page.url()}`, + level: 2, + }); + + this.logger({ + category: "action", + message: `Processing DOM...`, + level: 2, + }); + + const { outputString, selectorMap, chunk, chunks } = + await this.stagehand.page.evaluate( + ({ chunksSeen }: { chunksSeen: number[] }) => { + // @ts-ignore + return window.processDom(chunksSeen); + }, + { chunksSeen }, + ); + + this.logger({ + category: "action", + message: `Looking at chunk ${chunk}. Chunks left: ${ + chunks.length - chunksSeen.length + }`, + level: 1, + }); + + // Prepare annotated screenshot if vision is enabled + let annotatedScreenshot: Buffer | undefined; + if (useVision === true) { + if (!modelsWithVision.includes(model)) { + this.logger({ + category: "action", + message: `${model} does not support vision. Skipping vision processing.`, + level: 1, + }); + } else { + const screenshotService = new ScreenshotService( + this.stagehand.page, + selectorMap, + this.verbose, + ); + + annotatedScreenshot = + await screenshotService.getAnnotatedScreenshot(false); + } + } + + const response = await act({ + action, + domElements: outputString, + steps, + llmProvider: this.llmProvider, + modelName: model, + screenshot: annotatedScreenshot, + logger: this.logger, + requestId, + variables, + }); + + this.logger({ + category: "action", + message: `Received response from LLM: ${JSON.stringify(response)}`, + level: 1, + }); + + await this.cleanupDomDebug(); + + if (!response) { + if (chunksSeen.length + 1 < chunks.length) { + chunksSeen.push(chunk); + + this.logger({ + category: "action", + message: `No action found in current chunk. Chunks seen: ${chunksSeen.length}.`, + level: 1, + }); + + return this.act({ + action, + steps: + steps + + (!steps.endsWith("\n") ? "\n" : "") + + "## Step: Scrolled to another section\n", + chunksSeen, + modelName, + useVision, + verifierUseVision, + requestId, + variables, + previousSelectors, + skipActionCacheForThisStep, + domSettleTimeoutMs, + }); + } else if (useVision === "fallback") { + this.logger({ + category: "action", + message: `Switching to vision-based processing`, + level: 1, + }); + await this.stagehand.page.evaluate(() => window.scrollToHeight(0)); + return await this.act({ + action, + steps, + chunksSeen, + modelName, + useVision: true, + verifierUseVision, + requestId, + variables, + previousSelectors, + skipActionCacheForThisStep, + domSettleTimeoutMs, + }); + } else { + if (this.enableCaching) { + this.llmProvider.cleanRequestCache(requestId); + this.actionCache.deleteCacheForRequestId(requestId); + } + + return { + success: false, + message: `Action was not able to be completed.`, + action: action, + }; + } + } + + // Action found, proceed to execute + const elementId = response["element"]; + const xpaths = selectorMap[elementId]; + const method = response["method"]; + const args = response["args"]; + + // Get the element text from the outputString + const elementLines = outputString.split("\n"); + const elementText = + elementLines + .find((line) => line.startsWith(`${elementId}:`)) + ?.split(":")[1] || "Element not found"; + + this.logger({ + category: "action", + message: `Executing method: ${method} on element: ${elementId} (xpaths: ${xpaths.join( + ", ", + )}) with args: ${JSON.stringify(args)}`, + level: 1, + }); + + try { + const initialUrl = this.stagehand.page.url(); + const locator = this.stagehand.page + .locator(`xpath=${xpaths[0]}`) + .first(); + const originalUrl = this.stagehand.page.url(); + const componentString = await this._getComponentString(locator); + const responseArgs = [...args]; + if (variables) { + responseArgs.forEach((arg, index) => { + if (typeof arg === "string") { + args[index] = fillInVariables(arg, variables); + } + }); + } + await this._performPlaywrightMethod( + method, + args, + xpaths[0], + domSettleTimeoutMs, + ); + + const newStepString = + (!steps.endsWith("\n") ? "\n" : "") + + `## Step: ${response.step}\n` + + ` Element: ${elementText}\n` + + ` Action: ${response.method}\n` + + ` Reasoning: ${response.why}\n`; + + steps += newStepString; + + if (this.enableCaching) { + this.actionCache + .addActionStep({ + action, + url: originalUrl, + previousSelectors, + playwrightCommand: { + method, + args: responseArgs, + }, + componentString, + requestId, + xpaths: xpaths, + newStepString, + completed: response.completed, + }) + .catch((e) => { + this.logger({ + category: "action", + message: `Error adding action step to cache: ${e.message}\nTrace: ${e.stack}`, + level: 1, + }); + }); + } + + if (this.stagehand.page.url() !== initialUrl) { + steps += ` Result (Important): Page URL changed from ${initialUrl} to ${this.stagehand.page.url()}\n\n`; + } + + const actionCompleted = await this._verifyActionCompletion({ + completed: response.completed, + verifierUseVision, + requestId, + action, + steps, + model, + domSettleTimeoutMs, + }); + + if (!actionCompleted) { + this.logger({ + category: "action", + message: `Continuing to next action step`, + level: 1, + }); + + return this.act({ + action, + steps, + modelName, + chunksSeen, + useVision, + verifierUseVision, + requestId, + variables, + previousSelectors: [...previousSelectors, xpaths[0]], + skipActionCacheForThisStep: false, + domSettleTimeoutMs, + }); + } else { + this.logger({ + category: "action", + message: `Action completed successfully`, + level: 1, + }); + await this._recordAction(action, response.step); + return { + success: true, + message: `Action completed successfully: ${steps}${response.step}`, + action: action, + }; + } + } catch (error) { + this.logger({ + category: "action", + message: `Error performing action - D (Retries: ${retries}): ${error.message}\nTrace: ${error.stack}`, + level: 1, + }); + + if (retries < 2) { + return this.act({ + action, + steps, + modelName, + useVision, + verifierUseVision, + retries: retries + 1, + chunksSeen, + requestId, + variables, + previousSelectors, + skipActionCacheForThisStep, + domSettleTimeoutMs, + }); + } + + await this._recordAction(action, ""); + if (this.enableCaching) { + this.llmProvider.cleanRequestCache(requestId); + this.actionCache.deleteCacheForRequestId(requestId); + } + + return { + success: false, + message: `Error performing action - A: ${error.message}`, + action: action, + }; + } + } catch (error) { + this.logger({ + category: "action", + message: `Error performing action - B: ${error.message}\nTrace: ${error.stack}`, + level: 1, + }); + + if (this.enableCaching) { + this.llmProvider.cleanRequestCache(requestId); + this.actionCache.deleteCacheForRequestId(requestId); + } + + return { + success: false, + message: `Error performing action - C: ${error.message}`, + action: action, + }; + } + } +} diff --git a/lib/index.ts b/lib/index.ts index f185bfe5..eea3f954 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -1,13 +1,15 @@ import { type Page, type BrowserContext, chromium } from "@playwright/test"; -import crypto from "crypto"; import { z } from "zod"; import fs from "fs"; import { Browserbase } from "@browserbasehq/sdk"; -import { act, extract, observe, verifyActCompletion } from "./inference"; +import { extract, observe } from "./inference"; import { AvailableModel, LLMProvider } from "./llm/LLMProvider"; import path from "path"; import { ScreenshotService } from "./vision"; import { modelsWithVision } from "./llm/LLMClient"; +import { ActionCache } from "./cache/ActionCache"; +import { StagehandActHandler } from "./handlers/actHandler"; +import { generateId } from "./utils"; require("dotenv").config({ path: ".env" }); @@ -228,7 +230,6 @@ export class Stagehand { instruction: string; }; }; - private actions: { [key: string]: { result: string; action: string } }; public page: Page; public context: BrowserContext; private env: "LOCAL" | "BROWSERBASE"; @@ -246,6 +247,8 @@ export class Stagehand { private domSettleTimeoutMs: number; private browserBaseSessionCreateParams?: Browserbase.Sessions.SessionCreateParams; private enableCaching: boolean; + private variables: { [key: string]: any }; + private actHandler: StagehandActHandler; private browserbaseResumeSessionID?: string; constructor( @@ -290,21 +293,35 @@ export class Stagehand { llmProvider || new LLMProvider(this.logger, this.enableCaching); this.env = env; this.observations = {}; - this.apiKey = apiKey; - this.projectId = projectId; - this.actions = {}; + this.apiKey = apiKey || process.env.BROWSERBASE_API_KEY; + this.projectId = projectId || process.env.BROWSERBASE_PROJECT_ID; this.verbose = verbose ?? 0; this.debugDom = debugDom ?? false; this.defaultModelName = "gpt-4o"; this.domSettleTimeoutMs = domSettleTimeoutMs ?? 30_000; this.headless = headless ?? false; this.browserBaseSessionCreateParams = browserBaseSessionCreateParams; + this.actHandler = new StagehandActHandler({ + stagehand: this, + verbose: this.verbose, + llmProvider: this.llmProvider, + enableCaching: this.enableCaching, + logger: this.logger, + waitForSettledDom: this._waitForSettledDom.bind(this), + defaultModelName: this.defaultModelName, + startDomDebug: this.startDomDebug.bind(this), + cleanupDomDebug: this.cleanupDomDebug.bind(this), + }); this.browserbaseResumeSessionID = browserbaseResumeSessionID; } async init({ modelName = "gpt-4o", - }: { modelName?: AvailableModel } = {}): Promise<{ + domSettleTimeoutMs, + }: { + modelName?: AvailableModel; + domSettleTimeoutMs?: number; + } = {}): Promise<{ debugUrl: string; sessionUrl: string; }> { @@ -323,6 +340,7 @@ export class Stagehand { this.context = context; this.page = context.pages()[0]; this.defaultModelName = modelName; + this.domSettleTimeoutMs = domSettleTimeoutMs ?? this.domSettleTimeoutMs; // Overload the page.goto method const originalGoto = this.page.goto.bind(this.page); @@ -377,6 +395,10 @@ export class Stagehand { } // Add initialization scripts + await this.page.addInitScript({ + path: path.join(__dirname, "..", "dist", "dom", "build", "xpathUtils.js"), + }); + await this.page.addInitScript({ path: path.join(__dirname, "..", "dist", "dom", "build", "process.js"), }); @@ -558,30 +580,17 @@ export class Stagehand { } } - // Recording - private _generateId(operation: string) { - return crypto.createHash("sha256").update(operation).digest("hex"); - } - private async _recordObservation( instruction: string, result: { selector: string; description: string }[], ): Promise { - const id = this._generateId(instruction); + const id = generateId(instruction); this.observations[id] = { result, instruction }; return id; } - private async _recordAction(action: string, result: string): Promise { - const id = this._generateId(action); - - this.actions[id] = { result, action }; - - return id; - } - // Main methods private async _extract({ @@ -747,7 +756,7 @@ export class Stagehand { return { ...rest, - selector: `xpath=${selectorMap[elementId]}`, + selector: `xpath=${selectorMap[elementId][0]}`, }; }, ); @@ -766,569 +775,17 @@ export class Stagehand { return elementsWithSelectors; } - private async _act({ - action, - steps = "", - chunksSeen, - modelName, - useVision, - verifierUseVision, - retries = 0, - requestId, - domSettleTimeoutMs, - }: { - action: string; - steps?: string; - chunksSeen: number[]; - modelName?: AvailableModel; - useVision: boolean | "fallback"; - verifierUseVision: boolean; - retries?: number; - requestId?: string; - domSettleTimeoutMs?: number; - }): Promise<{ success: boolean; message: string; action: string }> { - const model = modelName ?? this.defaultModelName; - - if ( - !modelsWithVision.includes(model) && - (useVision !== false || verifierUseVision) - ) { - this.log({ - category: "action", - message: `${model} does not support vision, but useVision was set to ${useVision}. Defaulting to false.`, - level: 1, - }); - useVision = false; - verifierUseVision = false; - } - - this.log({ - category: "action", - message: `Running / Continuing action: ${action} on page: ${this.page.url()}`, - level: 2, - }); - - await this._waitForSettledDom(domSettleTimeoutMs); - - await this.startDomDebug(); - - this.log({ - category: "action", - message: `Processing DOM...`, - level: 2, - }); - - const { outputString, selectorMap, chunk, chunks } = - await this.page.evaluate( - ({ chunksSeen }: { chunksSeen: number[] }) => { - // @ts-ignore - return window.processDom(chunksSeen); - }, - { chunksSeen }, - ); - - this.log({ - category: "action", - message: `Looking at chunk ${chunk}. Chunks left: ${ - chunks.length - chunksSeen.length - }`, - level: 1, - }); - - // Prepare annotated screenshot if vision is enabled - let annotatedScreenshot: Buffer | undefined; - if (useVision === true) { - if (!modelsWithVision.includes(model)) { - this.log({ - category: "action", - message: `${model} does not support vision. Skipping vision processing.`, - level: 1, - }); - } else { - const screenshotService = new ScreenshotService( - this.page, - selectorMap, - this.verbose, - ); - - annotatedScreenshot = - await screenshotService.getAnnotatedScreenshot(false); - } - } - - const response = await act({ - action, - domElements: outputString, - steps, - llmProvider: this.llmProvider, - modelName: model, - screenshot: annotatedScreenshot, - logger: this.logger, - requestId, - }); - - this.log({ - category: "action", - message: `Received response from LLM: ${JSON.stringify(response)}`, - level: 1, - }); - - await this.cleanupDomDebug(); - - if (!response) { - if (chunksSeen.length + 1 < chunks.length) { - chunksSeen.push(chunk); - - this.log({ - category: "action", - message: `No action found in current chunk. Chunks seen: ${chunksSeen.length}.`, - level: 1, - }); - - return this._act({ - action, - steps: - steps + - (!steps.endsWith("\n") ? "\n" : "") + - "## Step: Scrolled to another section\n", - chunksSeen, - modelName, - useVision, - verifierUseVision, - requestId, - domSettleTimeoutMs, - }); - } else if (useVision === "fallback") { - this.log({ - category: "action", - message: `Switching to vision-based processing`, - level: 1, - }); - await this.page.evaluate(() => window.scrollToHeight(0)); - return await this._act({ - action, - steps, - chunksSeen, - modelName, - useVision: true, - verifierUseVision, - requestId, - domSettleTimeoutMs, - }); - } else { - if (this.enableCaching) { - this.llmProvider.cleanRequestCache(requestId); - } - - return { - success: false, - message: `Action was not able to be completed.`, - action: action, - }; - } - } - - // Action found, proceed to execute - const elementId = response["element"]; - const xpath = selectorMap[elementId]; - const method = response["method"]; - const args = response["args"]; - - // Get the element text from the outputString - const elementLines = outputString.split("\n"); - const elementText = - elementLines - .find((line) => line.startsWith(`${elementId}:`)) - ?.split(":")[1] || "Element not found"; - - this.log({ - category: "action", - message: `Executing method: ${method} on element: ${elementId} (xpath: ${xpath}) with args: ${JSON.stringify( - args, - )}`, - level: 1, - }); - - let urlChangeString = ""; - - const locator = this.page.locator(`xpath=${xpath}`).first(); - try { - const initialUrl = this.page.url(); - if (method === "scrollIntoView") { - this.log({ - category: "action", - message: `Scrolling element into view`, - level: 2, - }); - try { - await locator - .evaluate((element) => { - element.scrollIntoView({ behavior: "smooth", block: "center" }); - }) - .catch((e: Error) => { - this.log({ - category: "action", - message: `Error scrolling element into view: ${e.message}\nTrace: ${e.stack}`, - level: 1, - }); - }); - } catch (e) { - this.log({ - category: "action", - message: `Error scrolling element into view (Retries ${retries}): ${e.message}\nTrace: ${e.stack}`, - level: 1, - }); - - if (retries < 2) { - return this._act({ - action, - steps, - modelName, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - domSettleTimeoutMs, - }); - } - } - } else if (method === "fill" || method === "type") { - try { - await locator.fill(""); - await locator.click(); - const text = args[0]; - for (const char of text) { - await this.page.keyboard.type(char, { - delay: Math.random() * 50 + 25, - }); - } - } catch (e) { - this.log({ - category: "action", - message: `Error filling element (Retries ${retries}): ${e.message}\nTrace: ${e.stack}`, - level: 1, - }); - - if (retries < 2) { - return this._act({ - action, - steps, - modelName, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - domSettleTimeoutMs, - }); - } - } - } else if (method === "press") { - try { - const key = args[0]; - await this.page.keyboard.press(key); - } catch (e) { - this.log({ - category: "action", - message: `Error pressing key (Retries ${retries}): ${e.message}\nTrace: ${e.stack}`, - level: 1, - }); - - if (retries < 2) { - return this._act({ - action, - steps, - modelName, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - domSettleTimeoutMs, - }); - } - } - } else if ( - typeof locator[method as keyof typeof locator] === "function" - ) { - // Log current URL before action - this.log({ - category: "action", - message: `Page URL before action: ${this.page.url()}`, - level: 2, - }); - - // Perform the action - try { - // @ts-ignore - await locator[method](...args); - } catch (e) { - this.log({ - category: "action", - message: `Error performing method ${method} with args ${JSON.stringify( - args, - )} (Retries: ${retries}): ${e.message}\nTrace: ${e.stack}`, - level: 1, - }); - - if (retries < 2) { - return this._act({ - action, - steps, - modelName, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - domSettleTimeoutMs, - }); - } - } - - // Handle navigation if a new page is opened - if (method === "click") { - this.log({ - category: "action", - message: `Clicking element, checking for page navigation`, - level: 1, - }); - - // NAVIDNOTE: Should this happen before we wait for locator[method]? - const newOpenedTab = await Promise.race([ - new Promise((resolve) => { - this.context.once("page", (page) => resolve(page)); - setTimeout(() => resolve(null), 1_500); - }), - ]); - - this.log({ - category: "action", - message: `Clicked element, ${ - newOpenedTab ? "opened a new tab" : "no new tabs opened" - }`, - level: 1, - }); - - if (newOpenedTab) { - this.log({ - category: "action", - message: `New page detected (new tab) with URL: ${newOpenedTab.url()}`, - level: 1, - }); - await newOpenedTab.close(); - await this.page.goto(newOpenedTab.url()); - await this.page.waitForLoadState("domcontentloaded"); - await this._waitForSettledDom(domSettleTimeoutMs); - } - - // Wait for the network to be idle with timeout of 5s (will only wait if loading a new page) - // await this.waitForSettledDom(); - await Promise.race([ - this.page.waitForLoadState("networkidle"), - new Promise((resolve) => setTimeout(resolve, 5_000)), - ]).catch((e: Error) => { - this.log({ - category: "action", - message: `Network idle timeout hit`, - level: 1, - }); - }); - - this.log({ - category: "action", - message: `Finished waiting for (possible) page navigation`, - level: 1, - }); - - if (this.page.url() !== initialUrl) { - this.log({ - category: "action", - message: `New page detected with URL: ${this.page.url()}`, - level: 1, - }); - } - } - } else { - this.log({ - category: "action", - message: `Chosen method ${method} is invalid`, - level: 1, - }); - if (retries < 2) { - return this._act({ - action, - steps, - modelName: model, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - }); - } else { - if (this.enableCaching) { - this.llmProvider.cleanRequestCache(requestId); - } - - return { - success: false, - message: `Internal error: Chosen method ${method} is invalid`, - action: action, - }; - } - } - - let newSteps = - steps + - (!steps.endsWith("\n") ? "\n" : "") + - `## Step: ${response.step}\n` + - ` Element: ${elementText}\n` + - ` Action: ${response.method}\n` + - ` Reasoning: ${response.why}\n`; - - if (urlChangeString) { - newSteps += ` Result (Important): ${urlChangeString}\n\n`; - } - - let actionComplete = false; - if (response.completed) { - // Run action completion verifier - this.log({ - category: "action", - message: `Action marked as completed, Verifying if this is true...`, - level: 1, - }); - - let domElements: string | undefined = undefined; - let fullpageScreenshot: Buffer | undefined = undefined; - - if (verifierUseVision) { - try { - const screenshotService = new ScreenshotService( - this.page, - selectorMap, - this.verbose, - ); - - fullpageScreenshot = await screenshotService.getScreenshot( - true, - 15, - ); - } catch (e) { - this.log({ - category: "action", - message: `Error getting full page screenshot: ${e.message}\n. Trying again...`, - level: 1, - }); - - const screenshotService = new ScreenshotService( - this.page, - selectorMap, - this.verbose, - ); - - fullpageScreenshot = await screenshotService.getScreenshot( - true, - 15, - ); - } - } else { - ({ outputString: domElements } = await this.page.evaluate(() => { - return window.processAllOfDom(); - })); - } - - actionComplete = await verifyActCompletion({ - goal: action, - steps: newSteps, - llmProvider: this.llmProvider, - modelName: model, - screenshot: fullpageScreenshot, - domElements: domElements, - logger: this.logger, - requestId, - }); - - this.log({ - category: "action", - message: `Action completion verification result: ${actionComplete}`, - level: 1, - }); - } - - if (!actionComplete) { - this.log({ - category: "action", - message: `Continuing to next action step`, - level: 1, - }); - return this._act({ - action, - steps: newSteps, - modelName, - chunksSeen, - useVision, - verifierUseVision, - requestId, - domSettleTimeoutMs, - }); - } else { - this.log({ - category: "action", - message: `Action completed successfully`, - level: 1, - }); - await this._recordAction(action, response.step); - return { - success: true, - message: `Action completed successfully: ${steps}${response.step}`, - action: action, - }; - } - } catch (error) { - this.log({ - category: "action", - message: `Error performing action (Retries: ${retries}): ${error.message}\nTrace: ${error.stack}`, - level: 1, - }); - if (retries < 2) { - return this._act({ - action, - steps, - modelName, - useVision, - verifierUseVision, - retries: retries + 1, - chunksSeen, - requestId, - domSettleTimeoutMs, - }); - } - - await this._recordAction(action, ""); - if (this.enableCaching) { - this.llmProvider.cleanRequestCache(requestId); - } - - return { - success: false, - message: `Error performing action: ${error.message}`, - action: action, - }; - } - } - async act({ action, modelName, useVision = "fallback", + variables = {}, domSettleTimeoutMs, }: { action: string; modelName?: AvailableModel; useVision?: "fallback" | boolean; + variables?: Record; domSettleTimeoutMs?: number; }): Promise<{ success: boolean; @@ -1344,30 +801,35 @@ export class Stagehand { message: `Running act with action: ${action}, requestId: ${requestId}`, }); - return this._act({ - action, - modelName, - chunksSeen: [], - useVision, - verifierUseVision: useVision !== false, - requestId, - domSettleTimeoutMs, - }).catch((e) => { - this.logger({ - category: "act", - message: `Error acting: ${e.message}\nTrace: ${e.stack}`, - }); + if (variables) { + this.variables = { ...this.variables, ...variables }; + } - if (this.enableCaching) { - this.llmProvider.cleanRequestCache(requestId); - } + return this.actHandler + .act({ + action, + modelName, + chunksSeen: [], + useVision, + verifierUseVision: useVision !== false, + requestId, + variables, + previousSelectors: [], + skipActionCacheForThisStep: false, + domSettleTimeoutMs, + }) + .catch((e) => { + this.logger({ + category: "act", + message: `Error acting: ${e.message}\nTrace: ${e.stack}`, + }); - return { - success: false, - message: `Internal error: Error acting: ${e.message}`, - action: action, - }; - }); + return { + success: false, + message: `Internal error: Error acting: ${e.message}`, + action: action, + }; + }); } async extract({ diff --git a/lib/inference.ts b/lib/inference.ts index d9a58629..8f9e953b 100644 --- a/lib/inference.ts +++ b/lib/inference.ts @@ -84,6 +84,18 @@ export async function verifyActCompletion({ return response.completed; } +export function fillInVariables( + text: string, + variables: Record, +) { + let processedText = text; + Object.entries(variables).forEach(([key, value]) => { + const placeholder = `<|${key.toUpperCase()}|>`; + processedText = processedText.replace(placeholder, value); + }); + return processedText; +} + export async function act({ action, domElements, @@ -94,6 +106,7 @@ export async function act({ retries = 0, logger, requestId, + variables, }: { action: string; steps?: string; @@ -104,6 +117,7 @@ export async function act({ retries?: number; logger: (message: { category?: string; message: string }) => void; requestId: string; + variables?: Record; }): Promise<{ method: string; element: number; @@ -115,7 +129,7 @@ export async function act({ const llmClient = llmProvider.getClient(modelName, requestId); const messages: ChatMessage[] = [ buildActSystemPrompt(), - buildActUserPrompt(action, steps, domElements), + buildActUserPrompt(action, steps, domElements, variables), ]; const response = await llmClient.createChatCompletion({ @@ -133,10 +147,12 @@ export async function act({ }); const toolCalls = response.choices[0].message.tool_calls; + if (toolCalls && toolCalls.length > 0) { if (toolCalls[0].function.name === "skipSection") { return null; } + return JSON.parse(toolCalls[0].function.arguments); } else { if (retries >= 2) { diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 296cb10d..5bf44e61 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -1,7 +1,7 @@ import Anthropic from "@anthropic-ai/sdk"; import { LLMClient, ChatCompletionOptions } from "./LLMClient"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { LLMCache } from "./LLMCache"; +import { LLMCache } from "../cache/LLMCache"; export class AnthropicClient implements LLMClient { private client: Anthropic; @@ -50,7 +50,18 @@ export class AnthropicClient implements LLMClient { if (this.enableCaching) { const cachedResponse = await this.cache.get(cacheOptions, this.requestId); if (cachedResponse) { + this.logger({ + category: "llm_cache", + message: `LLM Cache hit - returning cached response`, + level: 1, + }); return cachedResponse; + } else { + this.logger({ + category: "llm_cache", + message: `LLM Cache miss - no cached response found`, + level: 1, + }); } } diff --git a/lib/llm/LLMCache.ts b/lib/llm/LLMCache.ts deleted file mode 100644 index edc275e1..00000000 --- a/lib/llm/LLMCache.ts +++ /dev/null @@ -1,332 +0,0 @@ -import * as fs from "fs"; -import * as path from "path"; -import * as crypto from "crypto"; - -interface CacheEntry { - timestamp: number; - response: any; - requestId: string; -} - -interface CacheStore { - [key: string]: CacheEntry; -} - -export class LLMCache { - private cacheDir: string; - private cacheFile: string; - private logger: (message: { - category?: string; - message: string; - level?: number; - }) => void; - private lockFile: string; - - private readonly CACHE_MAX_AGE_MS = 7 * 24 * 60 * 60 * 1000; // 1 week in milliseconds - private readonly CLEANUP_PROBABILITY = 0.01; // 1% chance - private readonly LOCK_TIMEOUT_MS = 1_000; - private lock_acquired = false; - private count_lock_acquire_failures = 0; - private request_id_to_used_hashes: { [key: string]: string[] } = {}; - - constructor( - logger: (message: { - category?: string; - message: string; - level?: number; - }) => void, - cacheDir: string = path.join(process.cwd(), "tmp", ".cache"), - cacheFile: string = "llm_calls.json", - ) { - this.logger = logger; - this.cacheDir = cacheDir; - this.cacheFile = path.join(cacheDir, cacheFile); - this.lockFile = path.join(cacheDir, "llm_cache.lock"); - this.ensureCacheDirectory(); - - // Handle process exit events (to make sure we release the lock) - this.setupProcessHandlers(); - } - - private setupProcessHandlers(): void { - const releaseLockAndExit = () => { - this.releaseLock(); - process.exit(); - }; - - process.on("exit", releaseLockAndExit); - process.on("SIGINT", releaseLockAndExit); - process.on("SIGTERM", releaseLockAndExit); - process.on("uncaughtException", (err) => { - this.logger({ - category: "llm_cache", - message: `Uncaught exception: ${err}`, - level: 2, - }); - if (this.lock_acquired) { - releaseLockAndExit(); - } - }); - } - - private ensureCacheDirectory(): void { - if (!fs.existsSync(this.cacheDir)) { - fs.mkdirSync(this.cacheDir, { recursive: true }); - } - } - - private createHash(data: any): string { - const hash = crypto.createHash("sha256"); - return hash.update(JSON.stringify(data)).digest("hex"); - } - - private sleep(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); - } - - private async acquireLock(): Promise { - const startTime = Date.now(); - while (Date.now() - startTime < this.LOCK_TIMEOUT_MS) { - try { - if (fs.existsSync(this.lockFile)) { - const lockAge = Date.now() - fs.statSync(this.lockFile).mtimeMs; - if (lockAge > this.LOCK_TIMEOUT_MS) { - fs.unlinkSync(this.lockFile); - } - } - - fs.writeFileSync(this.lockFile, process.pid.toString(), { flag: "wx" }); - this.count_lock_acquire_failures = 0; - this.lock_acquired = true; - return true; - } catch (error) { - await this.sleep(5); - } - } - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock after timeout", - level: 2, - }); - this.count_lock_acquire_failures++; - if (this.count_lock_acquire_failures >= 3) { - this.logger({ - category: "llm_cache", - message: - "Failed to acquire lock 3 times in a row. Releasing lock manually.", - level: 1, - }); - this.releaseLock(); - } - return false; - } - - private releaseLock(): void { - try { - if (fs.existsSync(this.lockFile)) { - fs.unlinkSync(this.lockFile); - } - this.lock_acquired = false; - } catch (error) { - this.logger({ - category: "llm_cache", - message: `Error releasing lock: ${error}`, - level: 2, - }); - } - } - - private readCache(): CacheStore { - if (fs.existsSync(this.cacheFile)) { - return JSON.parse(fs.readFileSync(this.cacheFile, "utf-8")); - } - - return {}; - } - - private writeCache(cache: CacheStore): void { - try { - if (Math.random() < this.CLEANUP_PROBABILITY) { - this.cleanupStaleEntries(cache); - } - fs.writeFileSync(this.cacheFile, JSON.stringify(cache, null, 2)); - } finally { - this.releaseLock(); - } - } - - private cleanupStaleEntries(cache: CacheStore): void { - if (!this.acquireLock()) { - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock for cleaning up cache", - level: 2, - }); - return; - } - - try { - const now = Date.now(); - let entriesRemoved = 0; - - for (const [hash, entry] of Object.entries(cache)) { - if (now - entry.timestamp > this.CACHE_MAX_AGE_MS) { - delete cache[hash]; - entriesRemoved++; - } - } - - if (entriesRemoved > 0) { - this.logger({ - category: "llm_cache", - message: `Cleaned up ${entriesRemoved} stale cache entries`, - level: 1, - }); - } - } catch (error) { - this.logger({ - category: "llm_cache", - message: `Error cleaning up stale cache entries: ${error}`, - level: 1, - }); - } finally { - this.releaseLock(); - } - } - - resetCache(): void { - if (!this.acquireLock()) { - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock for resetting cache", - level: 2, - }); - return; - } - - try { - this.ensureCacheDirectory(); - fs.writeFileSync(this.cacheFile, "{}"); - } finally { - this.releaseLock(); - } - } - - async get(options: any, requestId: string): Promise { - if (!(await this.acquireLock())) { - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock for getting cache", - level: 2, - }); - return null; - } - - try { - const hash = this.createHash(options); - const cache = this.readCache(); - - if (cache[hash]) { - this.logger({ - category: "llm_cache", - message: "Cache hit", - level: 1, - }); - this.request_id_to_used_hashes[requestId] ??= []; - this.request_id_to_used_hashes[requestId].push(hash); - return cache[hash].response; - } - return null; - } catch (error) { - this.logger({ - category: "llm_cache", - message: `Error getting cache: ${error}. Resetting cache.`, - level: 1, - }); - - this.resetCache(); - return null; - } finally { - this.releaseLock(); - } - } - - async deleteCacheForRequestId(requestId: string): Promise { - if (!(await this.acquireLock())) { - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock for deleting cache", - level: 2, - }); - return; - } - - try { - const cache = this.readCache(); - - let entriesRemoved = []; - for (const hash of this.request_id_to_used_hashes[requestId] ?? []) { - if (cache[hash]) { - entriesRemoved.push(cache[hash]); - delete cache[hash]; - } - } - - this.logger({ - category: "llm_cache", - message: `Deleted ${entriesRemoved.length} cache entries for requestId ${requestId}`, - level: 1, - }); - - this.writeCache(cache); - } catch (exception) { - this.logger({ - category: "llm_cache", - message: `Error deleting cache for requestId ${requestId}: ${exception}`, - level: 1, - }); - } finally { - this.releaseLock(); - } - } - - async set(options: any, response: any, requestId: string): Promise { - if (!(await this.acquireLock())) { - this.logger({ - category: "llm_cache", - message: "Failed to acquire lock for setting cache", - level: 2, - }); - return; - } - - try { - const hash = this.createHash(options); - const cache = this.readCache(); - cache[hash] = { - response: response, - timestamp: Date.now(), - requestId, - }; - - this.writeCache(cache); - this.request_id_to_used_hashes[requestId] ??= []; - this.request_id_to_used_hashes[requestId].push(hash); - this.logger({ - category: "llm_cache", - message: "Cache miss - saved new response", - level: 1, - }); - } catch (error) { - this.logger({ - category: "llm_cache", - message: `Error setting cache: ${error}. Resetting cache.`, - level: 1, - }); - - this.resetCache(); - } finally { - this.releaseLock(); - } - } -} diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index 24e73442..20c9bebd 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -1,7 +1,7 @@ import { OpenAIClient } from "./OpenAIClient"; import { AnthropicClient } from "./AnthropicClient"; import { LLMClient } from "./LLMClient"; -import { LLMCache } from "./LLMCache"; +import { LLMCache } from "../cache/LLMCache"; export type AvailableModel = | "gpt-4o" diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 59ea2358..3253f6f4 100644 --- a/lib/llm/OpenAIClient.ts +++ b/lib/llm/OpenAIClient.ts @@ -1,7 +1,7 @@ import OpenAI from "openai"; import { zodResponseFormat } from "openai/helpers/zod"; import { LLMClient, ChatCompletionOptions } from "./LLMClient"; -import { LLMCache } from "./LLMCache"; +import { LLMCache } from "../cache/LLMCache"; export class OpenAIClient implements LLMClient { private client: OpenAI; @@ -46,7 +46,18 @@ export class OpenAIClient implements LLMClient { if (this.enableCaching) { const cachedResponse = await this.cache.get(cacheOptions, this.requestId); if (cachedResponse) { + this.logger({ + category: "llm_cache", + message: `LLM Cache hit - returning cached response`, + level: 1, + }); return cachedResponse; + } else { + this.logger({ + category: "llm_cache", + message: `LLM Cache miss - no cached response found`, + level: 1, + }); } } diff --git a/lib/prompt.ts b/lib/prompt.ts index 642e33d1..a6e1e4c5 100644 --- a/lib/prompt.ts +++ b/lib/prompt.ts @@ -10,6 +10,7 @@ You are given: 1. the user's overall goal 2. the steps that you've taken so far 3. a list of active DOM elements in this chunk to consider to get closer to the goal. +4. Optionally, a list of variable names that the user has provided that you may use to accomplish the goal. To use the variables, you must use the special <|VARIABLE_NAME|> syntax. You have 2 tools that you can call: doAction, and skipSection. Do action only performs Playwright actions. Do not perform any other actions. @@ -103,8 +104,9 @@ export function buildActUserPrompt( action: string, steps = "None", domElements: string, + variables?: Record, ): ChatMessage { - const actUserPrompt = ` + let actUserPrompt = ` # My Goal ${action} @@ -115,6 +117,15 @@ ${steps} ${domElements} `; + if (variables) { + actUserPrompt += ` +# Variables +${Object.entries(variables) + .map(([key, value]) => `<|${key.toUpperCase()}|>`) + .join("\n")} +`; + } + return { role: "user", content: actUserPrompt, diff --git a/lib/types.ts b/lib/types.ts new file mode 100644 index 00000000..31380439 --- /dev/null +++ b/lib/types.ts @@ -0,0 +1,13 @@ +export class PlaywrightCommandException extends Error { + constructor(message: string) { + super(message); + this.name = "PlaywrightCommandException"; + } +} + +export class PlaywrightCommandMethodNotSupportedException extends Error { + constructor(message: string) { + super(message); + this.name = "PlaywrightCommandMethodNotSupportedException"; + } +} diff --git a/lib/utils.ts b/lib/utils.ts new file mode 100644 index 00000000..c9d20e4f --- /dev/null +++ b/lib/utils.ts @@ -0,0 +1,5 @@ +import crypto from "crypto"; + +export function generateId(operation: string) { + return crypto.createHash("sha256").update(operation).digest("hex"); +} diff --git a/lib/vision.ts b/lib/vision.ts index 1cdac5f2..5779e572 100644 --- a/lib/vision.ts +++ b/lib/vision.ts @@ -19,7 +19,7 @@ type NumberPosition = { export class ScreenshotService { private page: Page; - private selectorMap: Record; + private selectorMap: Record; private annotationBoxes: AnnotationBox[] = []; private numberPositions: NumberPosition[] = []; private isDebugEnabled: boolean; @@ -27,7 +27,7 @@ export class ScreenshotService { constructor( page: Page, - selectorMap: Record, + selectorMap: Record, verbose: 0 | 1 | 2, isDebugEnabled: boolean = false, ) { @@ -104,8 +104,8 @@ export class ScreenshotService { // }); const svgAnnotations = await Promise.all( - Object.entries(this.selectorMap).map(async ([id, selector]) => - this.createElementAnnotation(id, selector), + Object.entries(this.selectorMap).map(async ([id, selectors]) => + this.createElementAnnotation(id, selectors), ), ); @@ -135,19 +135,29 @@ export class ScreenshotService { private async createElementAnnotation( id: string, - selector: string, + selectors: string[], ): Promise { try { - const element = await this.page.locator(`xpath=${selector}`).first(); - const box = await element.boundingBox(); + let element = null; + + // Try each selector until one works + const selectorPromises: Promise[] = selectors.map( + async (selector) => { + try { + element = await this.page.locator(`xpath=${selector}`).first(); + const box = await element.boundingBox({ timeout: 5_000 }); + return box; + } catch (e) { + return null; + } + }, + ); + + const boxes = await Promise.all(selectorPromises); + const box = boxes.find((b) => b !== null); if (!box) { - this.log({ - category: "Debug", - message: `No bounding box for element ${id}`, - level: 2, - }); - return ""; + throw new Error(`Unable to create annotation for element ${id}`); } const scrollPosition = await this.page.evaluate(() => ({ @@ -180,8 +190,8 @@ export class ScreenshotService { `; } catch (error) { this.log({ - category: "Error", - message: `Failed to create annotation for element ${id}: ${error}`, + category: "Vision", + message: `Warning: Failed to create annotation for element ${id}: ${error}, trace: ${error.stack}`, level: 0, }); return "";