Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow methodologies to configure the test name #1251

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions src/server/bandit/banditData.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { isProd } from '../lib/env';
import * as AWS from 'aws-sdk';
import { buildReloader, ValueProvider } from '../utils/valueReloader';
import { BannerTest, Channel, EpicTest, Test, Variant } from '../../shared/types';
import { BannerTest, Channel, EpicTest, Methodology, Test, Variant } from '../../shared/types';
import { z } from 'zod';
import { logError } from '../utils/logging';
import { putMetric } from '../utils/cloudwatch';
Expand Down Expand Up @@ -68,25 +68,24 @@ export interface BanditData {
bestVariants: BanditVariantData[]; // will contain more than 1 variant if there is a tie
}

function getDefaultWeighting<V extends Variant, T extends Test<V>>(test: T): BanditData {
function getDefaultWeighting(test: BanditTestConfig): BanditData {
// No samples yet, set all means to zero to allow random selection
return {
testName: test.name,
bestVariants: test.variants.map((variant) => ({
variantName: variant.name,
testName: test.testName,
bestVariants: test.variantNames.map((variantName) => ({
variantName,
mean: 0,
})),
};
}

function calculateMeanPerVariant<V extends Variant, T extends Test<V>>(
function calculateMeanPerVariant(
samples: TestSample[],
test: T,
test: BanditTestConfig,
): BanditVariantData[] {
const allVariantSamples = samples.flatMap((sample) => sample.variants);
const variantNames = test.variants.map((variant) => variant.name);

return variantNames.map((variantName) => {
return test.variantNames.map((variantName) => {
const variantSamples = allVariantSamples.filter(
(variantSample) => variantSample.variantName === variantName,
);
Expand Down Expand Up @@ -114,18 +113,16 @@ function calculateBestVariants(variantMeans: BanditVariantData[]): BanditVariant
return variantMeans.filter((variant) => variant.mean === highestMean);
}

async function buildBanditDataForTest<V extends Variant, T extends Test<V>>(
test: T,
): Promise<BanditData> {
if (test.variants.length === 0) {
async function buildBanditDataForTest(test: BanditTestConfig): Promise<BanditData> {
if (test.variantNames.length === 0) {
// No variants have been added to the test yet
return {
testName: test.name,
testName: test.testName,
bestVariants: [],
};
}

const samples = await getBanditSamplesForTest(test.name, test.channel);
const samples = await getBanditSamplesForTest(test.testName, test.channel);

if (samples.length < MINIMUM_SAMPLES) {
return getDefaultWeighting(test);
Expand All @@ -135,27 +132,42 @@ async function buildBanditDataForTest<V extends Variant, T extends Test<V>>(
const bestVariants = calculateBestVariants(variantMeans);

return {
testName: test.name,
testName: test.testName,
bestVariants,
};
}

function hasBanditMethodology<V extends Variant, T extends Test<V>>(test: T): boolean {
return !!test.methodologies?.find((method) => method.name === 'EpsilonGreedyBandit');
interface BanditTestConfig {
testName: string;
channel: Channel;
variantNames: string[];
}

// Return config for each bandit methodology in this test
function getBanditTestConfigs<V extends Variant, T extends Test<V>>(test: T): BanditTestConfig[] {
const bandits: Methodology[] = test.methodologies?.filter(
(method) => method.name === 'EpsilonGreedyBandit',
);
return bandits.map((method) => ({
testName: method.testName ?? test.name, // if the methodology should be tracked with a different name then use that
channel: test.channel,
variantNames: test.variants.map((v) => v.name),
}));
}

function buildBanditData(
epicTestsProvider: ValueProvider<EpicTest[]>,
bannerTestsProvider: ValueProvider<BannerTest[]>,
): Promise<BanditData[]> {
const banditTests = [...epicTestsProvider.get(), ...bannerTestsProvider.get()].filter(
hasBanditMethodology,
);
const allTests = [...epicTestsProvider.get(), ...bannerTestsProvider.get()];
// For each test, get any bandit methodologies so that we can fetch sample data
const banditTests: BanditTestConfig[] = allTests.flatMap((test) => getBanditTestConfigs(test));

return Promise.all(
banditTests.map((test) =>
buildBanditDataForTest(test).catch((error) => {
banditTests.map((banditTestConfig) =>
buildBanditDataForTest(banditTestConfig).catch((error) => {
logError(
`Error fetching bandit samples for test ${test.name} from Dynamo: ${error.message}`,
`Error fetching bandit samples for test ${banditTestConfig.name} from Dynamo: ${error.message}`,
);
putMetric('bandit-data-load-error');
return Promise.reject(error);
Expand Down
13 changes: 10 additions & 3 deletions src/server/lib/ab.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ describe('selectVariant', () => {
expect(result?.test.name).toEqual(test.name);
});

it('should return same test name if one methodology is configured', () => {
it('should return same test name if the methodology is configured with no testName', () => {
const testWithMethodology: EpicTest = {
...test,
methodologies: [{ name: 'ABTest' }],
Expand All @@ -149,10 +149,17 @@ describe('selectVariant', () => {
expect(result?.test.name).toEqual(test.name);
});

it('should return extended test name if one than one methodology is configured', () => {
it('should return extended test name if the methodology is configured with a testName', () => {
const testWithMethodology: EpicTest = {
...test,
methodologies: [{ name: 'ABTest' }, { name: 'EpsilonGreedyBandit', epsilon: 0.5 }],
methodologies: [
{ name: 'ABTest', testName: 'example-1_ABTest' },
{
name: 'EpsilonGreedyBandit',
epsilon: 0.5,
testName: 'example-1_EpsilonGreedyBandit-0.5',
},
],
};
const result = selectVariant(testWithMethodology, 1, []);
expect(result?.test.name).toBe('example-1_EpsilonGreedyBandit-0.5');
Expand Down
12 changes: 2 additions & 10 deletions src/server/lib/ab.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,6 @@ const selectVariantWithMethodology = <V extends Variant, T extends Test<V>>(
return selectVariantUsingMVT<V, T>(test, mvtId);
};

const addMethodologyToTestName = (testName: string, methodology: Methodology): string => {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer necessary as we can use the testName from the methodology

if (methodology.name === 'EpsilonGreedyBandit') {
return `${testName}_EpsilonGreedyBandit-${methodology.epsilon}`;
} else {
return `${testName}_ABTest`;
}
};

/**
* Selects a variant from the test based on any configured methodologies.
* Defaults to an AB test.
Expand Down Expand Up @@ -120,10 +112,10 @@ export const selectVariant = <V extends Variant, T extends Test<V>>(
const methodology =
test.methodologies[getRandomNumber(test.name, mvtId) % test.methodologies.length];

// Add the methodology to the test name so that we can track them separately
// if the methodology should be tracked with a different name then use that
const testWithNameExtension = {
...test,
name: addMethodologyToTestName(test.name, methodology),
name: methodology.testName ?? test.name,
};
const variant = selectVariantWithMethodology<V, T>(
testWithNameExtension,
Expand Down
9 changes: 5 additions & 4 deletions src/shared/types/abTests/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ const epsilonGreedyMethodologySchema = z.object({
name: z.literal('EpsilonGreedyBandit'),
epsilon: z.number(),
});
const methodologySchema = z.discriminatedUnion('name', [
abTestMethodologySchema,
epsilonGreedyMethodologySchema,
]);
const methodologySchema = z.intersection(
z.discriminatedUnion('name', [abTestMethodologySchema, epsilonGreedyMethodologySchema]),
// each methodology may have an optional testName, which should be used for tracking
z.object({ testName: z.string().optional() }),
);
export type Methodology = z.infer<typeof methodologySchema>;

export interface Variant {
Expand Down
Loading