Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Few-shots from the retrieved dataset #391

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
58 changes: 57 additions & 1 deletion prompt2model/dataset_generator/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,38 @@ async def generate_responses(
)
return responses

def create_retrieved_data_fewshot_string(
self, retrieved_dataset: Dataset, n_shots: int = 3
) -> str:
"""A method to sample instances from the dataset retrieved.

Args:
retrieved_dataset: the train set of the dataset from the dataset retriever
n_shots: how many instances is used as an example

Returns:
A string of pairs of examples sourced from the dataset retrieved.
formatted as if it was a string from the user
"""
sample_dataset = retrieved_dataset[
random.sample(range(retrieved_dataset.num_rows), n_shots)
]
return "\n\n".join(
[
f"input=\"{sample_dataset['input_col'][i]}\"\n"
+ f"output=\"{sample_dataset['output_col'][i]}\""
for i in range(n_shots)
]
)

def generate_dataset_split(
self,
prompt_spec: PromptSpec,
num_examples: int,
split: DatasetSplit = DatasetSplit.TRAIN,
retrieved_dataset: Dataset = None,
few_shot_method: str = "user",
n_shots: int = 3,
) -> Dataset:
"""Generates a dataset split using API-based LMs.

Expand All @@ -408,16 +435,39 @@ def generate_dataset_split(
Args:
prompt_spec: PromptParser to be used for generating examples.
num_examples: The number of examples to be generated.
retrieved_dataset: The human-annotated dataset retrieved
few_shot_method: a string that indicates fewshot source.
'user' means few-shots is from user examples,
'retrieved_fixed' is using the same examples everytime,
'retrieved_swapout' is using different examples everytime
n_shots: how many instances is used as an example

Returns:
The generated dataset split.
"""
if (
few_shot_method in ["retrieved_fixed", "retrieved_swapout"]
and retrieved_dataset is None
):
raise Exception(
f"retrieved_dataset can't be None if few shot is '{few_shot_method}'"
)

all_generated_examples: list[Example] = []
generated_examples: list[Example] = []

pbar = tqdm(total=num_examples, desc="Generating examples")
chat_api = api_tools.default_api_agent

if few_shot_method == "user":
few_shot_example_string = prompt_spec.examples
elif few_shot_method in ["retrieved_fixed", "retrieved_swapout"]:
few_shot_example_string = self.create_retrieved_data_fewshot_string(
retrieved_dataset, n_shots
)
else:
raise Exception(f"'{few_shot_method}' is not a recognized few shot method")

while len(generated_examples) < num_examples:
if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
logger.warning("Maximum number of API calls reached.")
Expand All @@ -430,7 +480,7 @@ def generate_dataset_split(
prompts = [
self.construct_prompt(
instruction=prompt_spec.instruction,
few_shot_example_string=prompt_spec.examples,
few_shot_example_string=few_shot_example_string,
generated_examples=generated_examples,
)
for _ in range(batch_size)
Expand Down Expand Up @@ -460,6 +510,12 @@ def generate_dataset_split(

pbar.update(len(generated_examples) - prev_length)

# regenerate few_shot_example_string if few_shot_method is retrieved_swapout
if few_shot_method == "retrieved_swapout":
few_shot_example_string = self.create_retrieved_data_fewshot_string(
retrieved_dataset, n_shots
)

if len(generated_examples) >= num_examples:
generated_examples = generated_examples[:num_examples]

Expand Down