allow fewshots for multimodal tasks #2450

Open · wants to merge 3 commits into base: main
42 changes: 37 additions & 5 deletions lm_eval/api/samplers.py
@@ -17,6 +17,7 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter

        # TODO: this block is repeated almost identically three times
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
@@ -50,6 +51,17 @@ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
else:
self.doc_to_choice = self.task.doc_to_choice

if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_image", None) is not None
):
self.doc_to_image = partial(
self.task.doc_to_image,
doc_to_image=self.config.fewshot_config.get("doc_to_image", None),
)
else:
self.doc_to_image = self.task.doc_to_image

self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
if not isinstance(self.docs, datasets.Dataset):
@@ -74,9 +86,16 @@ def get_context(self, doc, num_fewshot):
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = ""

        # collect image files for the few-shot examples
multimodal_args = {}

for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if self.config.doc_to_image:
doc_image = self.doc_to_image(doc)

labeled_examples += (
doc_content
if self.config.doc_to_choice is None or isinstance(doc_content, str)
@@ -94,7 +113,10 @@ def get_context(self, doc, num_fewshot):
)
labeled_examples += self.fewshot_delimiter

return labeled_examples
if self.config.doc_to_image:
multimodal_args.setdefault("visual", []).extend(doc_image)

return labeled_examples, multimodal_args

def get_chat_context(
self,
@@ -103,6 +125,10 @@
fewshot_as_multiturn: bool = False,
):
chat_history = []

        # collect image files for the few-shot examples
multimodal_args = {}

# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
@@ -120,6 +146,9 @@
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if self.config.doc_to_image:
doc_image = self.doc_to_image(doc)

chat_history.append(
{
"role": "user",
@@ -140,13 +169,16 @@
else str(self.doc_to_choice(doc)[doc_target]),
}
)
if self.config.doc_to_image:
multimodal_args.setdefault("visual", []).extend(doc_image)
else:
# get fewshot context as one user turn
chat_history.append(
{"role": "user", "content": self.get_context(doc, num_fewshot)}
)
context, multimodal_args = self.get_context(doc, num_fewshot)
chat_history.append({"role": "user", "content": context})
if self.config.doc_to_image:
multimodal_args.setdefault("visual", []).extend(doc_image)

return chat_history
return chat_history, multimodal_args

def sample(self, n):
"""
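The changes above make `get_context` and `get_chat_context` return a `(text, multimodal_args)` pair, where `multimodal_args` collects the images of every few-shot document under a single `"visual"` key. The following standalone sketch (not part of the PR; `build_fewshot_context` and the sample docs are hypothetical stand-ins) illustrates the aggregation pattern the new sampler code uses:

```python
# Illustrative sketch only: mirrors how the updated sampler aggregates few-shot
# images into a single multimodal_args dict alongside the prompt text.
# build_fewshot_context and the sample docs are hypothetical stand-ins,
# not lm-evaluation-harness APIs.

def build_fewshot_context(selected_docs, doc_to_text, doc_to_target, doc_to_image,
                          target_delimiter=" ", fewshot_delimiter="\n\n"):
    labeled_examples = ""
    multimodal_args = {}
    for doc in selected_docs:
        labeled_examples += doc_to_text(doc) + target_delimiter + str(doc_to_target(doc))
        labeled_examples += fewshot_delimiter
        # same pattern as the PR: one flat "visual" list shared across all shots
        multimodal_args.setdefault("visual", []).extend(doc_to_image(doc))
    return labeled_examples, multimodal_args


docs = [
    {"q": "What animal is shown?", "a": "cat", "images": ["<cat.png>"]},
    {"q": "What color is the sky?", "a": "blue", "images": ["<sky.png>"]},
]
ctx, mm_args = build_fewshot_context(
    docs,
    doc_to_text=lambda d: d["q"],
    doc_to_target=lambda d: d["a"],
    doc_to_image=lambda d: d["images"],
)
print(mm_args)  # {'visual': ['<cat.png>', '<sky.png>']}
```

Because every shot contributes to one flat `visual` list rather than being attached per turn, a downstream model wrapper presumably has to pair the images with the text (or chat turns) itself.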
50 changes: 27 additions & 23 deletions lm_eval/api/task.py
@@ -435,7 +435,7 @@ def build_all_requests(
total=num_docs,
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
fewshot_ctx, multimodal_args = self.fewshot_context(
doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
system_instruction,
@@ -448,6 +448,7 @@
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
multimodal_args=multimodal_args,
metadata=(self.config["task"], doc_id, self.config.repeats),
)

@@ -1044,6 +1045,9 @@ def fewshot_context(
The fewshot context.
"""

# empty dict by default
multimodal_args = {}

if apply_chat_template:
labeled_examples = []
else:
@@ -1075,31 +1079,33 @@
# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
)
fewshots, multimodal_args = self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
)
labeled_examples.extend(fewshots)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
fewshots, multimodal_args = self.sampler.get_context(doc, num_fewshot)
labeled_examples += fewshots

example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
return chat_template(labeled_examples)
return chat_template(labeled_examples), multimodal_args
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
multimodal_args_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
                    multimodal_args_list.append(multimodal_args)
return labeled_examples_list, multimodal_args_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
@@ -1112,20 +1118,22 @@
labeled_examples, str(example), fewshot_as_multiturn
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
return chat_template(labeled_examples), multimodal_args
else:
if self.multiple_input:
return labeled_examples
return labeled_examples, multimodal_args
if isinstance(example, str):
return labeled_examples + example
return labeled_examples + example, multimodal_args
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
return [labeled_examples + ex for ex in example], [
multimodal_args for ex in example
]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
return labeled_examples + choices[example], multimodal_args
else:
return labeled_examples + str(example)
return labeled_examples + str(example), multimodal_args

def apply_filters(self):
"""Iterates over FilterEnsembles and applies them to instances"""
@@ -1299,7 +1307,7 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
return None

def construct_requests(
self, doc: dict, ctx: str, **kwargs
self, doc: dict, ctx: str, multimodal_args: dict = {}, **kwargs
) -> Union[List[Instance], Instance]:
aux_arguments = None

@@ -1335,20 +1343,16 @@ def construct_requests(
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, deepcopy(self.config.generation_kwargs))

multimodal_arg = {}
if (
self.config.doc_to_image
): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = {
**multimodal_arg,
**{"visual": self.doc_to_image(doc)},
}
multimodal_args.setdefault("visual", []).extend(self.doc_to_image(doc))

if bool(multimodal_arg):
if bool(multimodal_args):
if isinstance(arguments, list):
arguments = [arg + (multimodal_arg,) for arg in arguments]
arguments = [arg + (multimodal_args,) for arg in arguments]
else:
arguments = arguments + (multimodal_arg,)
arguments = arguments + (multimodal_args,)

if self.OUTPUT_TYPE == "multiple_choice":
request_list = [
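With this change, `construct_requests` receives the few-shot `multimodal_args` dict, appends the current document's images to the same `"visual"` list, and attaches the dict to the request arguments tuple when it is non-empty. A rough, self-contained sketch of that merge follows; `attach_multimodal_args` and the sample values are hypothetical stand-ins, not the harness API, and the helper copies its inputs rather than mutating a shared dict:

```python
# Illustrative sketch only: shows how few-shot images and the target document's
# images end up in one dict that is appended to the generate_until arguments.
from copy import deepcopy


def attach_multimodal_args(ctx, generation_kwargs, doc_images, fewshot_multimodal_args):
    arguments = (ctx, deepcopy(generation_kwargs))
    # copy the few-shot dict so the caller's lists are not mutated
    merged = {key: list(values) for key, values in fewshot_multimodal_args.items()}
    merged.setdefault("visual", []).extend(doc_images)  # few-shot images first, then the target doc's
    if merged:
        arguments = arguments + (merged,)
    return arguments


args = attach_multimodal_args(
    ctx="<few-shot examples>\n\n<target question>",
    generation_kwargs={"max_gen_toks": 32},
    doc_images=["<target.png>"],
    fewshot_multimodal_args={"visual": ["<shot1.png>", "<shot2.png>"]},
)
print(args[2])  # {'visual': ['<shot1.png>', '<shot2.png>', '<target.png>']}
```

The ordering matters: few-shot images precede the target document's image in the `visual` list, matching the order in which the corresponding examples appear in the prompt text.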