add option to add an assistant_prefix #2545

Status: Draft. Wants to merge 5 commits into base: main.
lm_eval/api/samplers.py (3 additions, 0 deletions)
@@ -101,6 +101,7 @@ def get_chat_context(
         doc,
         num_fewshot,
         fewshot_as_multiturn: bool = False,
+        assistant_prefix: str = None,
     ):
         chat_history = []
         # draw an extra fewshot sample if using same split as evaluating on
@@ -145,6 +146,8 @@ def get_chat_context(
             chat_history.append(
                 {"role": "user", "content": self.get_context(doc, num_fewshot)}
             )
+        if assistant_prefix:
+            chat_history.append({"role": "assistant", "content": assistant_prefix})

         return chat_history

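For illustration, this is roughly the chat history get_chat_context would now return when an assistant prefix is configured; the question and prefix strings below are made-up placeholders, not values from the PR:

    # Hypothetical zero-shot example (any few-shot turns would precede it).
    assistant_prefix = "Let's think step by step."  # placeholder value
    chat_history = [
        {"role": "user", "content": "Question: What is 2 + 2?\nAnswer:"},
        # appended only when assistant_prefix is truthy
        {"role": "assistant", "content": assistant_prefix},
    ]

The prefix rides along as a trailing assistant turn, so whichever chat template is applied later decides how it is rendered.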
lm_eval/api/task.py (22 additions, 5 deletions)
@@ -92,6 +92,7 @@ class TaskConfig(dict):
     filter_list: Optional[Union[str, list]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
+    assistant_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -381,6 +382,7 @@ def build_all_requests(
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
+        assistant_prefix: Optional[str] = None,
         tokenizer_name: str = "",
     ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
@@ -442,6 +444,7 @@ def build_all_requests(
                 apply_chat_template,
                 fewshot_as_multiturn,
                 chat_template,
+                assistant_prefix,
             )

         # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -1000,6 +1003,7 @@ def append_target_question(
         labeled_examples: List[Dict[str, str]],
         question: str,
         fewshot_as_multiturn: bool = False,
+        assistant_prefix: Optional[str] = None,
     ) -> None:
         """Adds a target question to the labeled examples list.
         If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
@@ -1015,6 +1019,7 @@ def append_target_question(
         else:
             # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
             labeled_examples.append({"role": "user", "content": question})
+            labeled_examples.append({"role": "assistant", "content": assistant_prefix})

     @utils.positional_deprecated
     def fewshot_context(
@@ -1025,6 +1030,7 @@ def fewshot_context(
         apply_chat_template: bool = False,
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
+        assistant_prefix: Optional[str] = None,
     ) -> str:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1078,7 +1084,7 @@
             if apply_chat_template:
                 labeled_examples.extend(
                     self.sampler.get_chat_context(
-                        doc, num_fewshot, fewshot_as_multiturn
+                        doc, num_fewshot, fewshot_as_multiturn, assistant_prefix
                     )
                 )
             else:
@@ -1090,27 +1096,38 @@ def fewshot_context(
                 return chat_template(labeled_examples)
             if isinstance(example, str):
                 self.append_target_question(
-                    labeled_examples, example, fewshot_as_multiturn
+                    labeled_examples,
+                    example,
+                    fewshot_as_multiturn,
+                    self.config.assistant_prefix,
                 )
             # for loglikelihood create a list of questions with appended choices
             elif isinstance(example, list):
                 labeled_examples_list = []
                 # copy chat history for each example and append the answer
                 for ex in example:
                     chat = deepcopy(labeled_examples)
-                    self.append_target_question(chat, ex, fewshot_as_multiturn)
+                    self.append_target_question(
+                        chat, ex, fewshot_as_multiturn, self.config.assistant_prefix
+                    )
                     labeled_examples_list.append(chat_template(chat))
                 return labeled_examples_list
             # if example is an integer, append the choice or convert to string
             elif isinstance(example, int):
                 if self.config.doc_to_choice is not None:
                     choices = self.doc_to_choice(doc)
                     self.append_target_question(
-                        labeled_examples, choices[example], fewshot_as_multiturn
+                        labeled_examples,
+                        choices[example],
+                        fewshot_as_multiturn,
+                        self.config.assistant_prefix,
                     )
                 else:
                     self.append_target_question(
-                        labeled_examples, str(example), fewshot_as_multiturn
+                        labeled_examples,
+                        str(example),
+                        fewshot_as_multiturn,
+                        self.config.assistant_prefix,
                     )
             # return lm.apply_chat_template(labeled_examples)
             return chat_template(labeled_examples)
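As a rough sketch of how a task might carry the new field (the task values below are invented for illustration and assume the usual TaskConfig plumbing), the prefix set on the config is what fewshot_context passes down to append_target_question:

    # Hypothetical task configuration; only assistant_prefix is new in this PR.
    from lm_eval.api.task import TaskConfig

    config = TaskConfig(
        task="my_cot_task",                      # placeholder task name
        doc_to_text="{{question}}",
        doc_to_target="{{answer}}",
        assistant_prefix="Let's think step by step.",  # new optional field
    )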
lm_eval/models/huggingface.py (9 additions, 3 deletions)
@@ -1330,21 +1330,27 @@ def _collate(req: Tuple[str, dict]):

         return res

-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+    ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
         """
         try:
             chat_templated = self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
             )
         except jinja2.exceptions.TemplateError:
             eval_logger.warning(
                 "Failed to apply chat template. removing the system role in chat history."
             )
             chat_history = [msg for msg in chat_history if msg["role"] != "system"]
             chat_templated = self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
             )

         return chat_templated
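For context on why add_generation_prompt is now a parameter: when the chat history already ends with the assistant-prefix turn, the caller would presumably pass add_generation_prompt=False so the tokenizer's template does not open a second, empty assistant turn. A minimal sketch with a generic instruct tokenizer (the checkpoint name and messages are placeholders, not from the PR):

    from transformers import AutoTokenizer

    # Placeholder checkpoint; any tokenizer with a chat template behaves similarly.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    chat = [
        {"role": "user", "content": "Question: What is 2 + 2?\nAnswer:"},
        {"role": "assistant", "content": "Let's think step by step."},
    ]

    # With add_generation_prompt=False the template does not append a fresh,
    # empty assistant header after the prefix turn. Note that many templates
    # still close the final assistant message with an end-of-turn token; newer
    # transformers releases expose continue_final_message=True to leave it open.
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=False
    )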