add llama3 tasks #2556

Open · wants to merge 5 commits into `main`
44 changes: 44 additions & 0 deletions lm_eval/tasks/llama3/README.md
@@ -0,0 +1,44 @@
# llama3

### Paper

Title: ``

Abstract: ``


Homepage: ``


### Citation

```

```

### Groups, Tags, and Tasks

#### Groups

None.

#### Tags

* `llama`: applied to the base-model ARC task.
* `llama3`: applied to the instruct-model MGSM task.

#### Tasks

* `llama_arc_challenge`: 25-shot multiple-choice ARC-Challenge, scored with `acc` and `acc_norm`.
* `mgsm_chat`: 0-shot generative MGSM benchmark; run with a chat template applied (see the usage sketch below).
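A minimal sketch of invoking these tasks through the harness's Python entry point; the checkpoint name is a placeholder, and the equivalent CLI flags (`--tasks`, `--apply_chat_template`) work the same way:

```python
# Minimal usage sketch; the model checkpoint is a placeholder, and
# apply_chat_template assumes a recent lm-eval with chat-template support.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.2-3B-Instruct",
    tasks=["mgsm_chat"],
    apply_chat_template=True,  # mgsm_chat is built for chat-formatted prompts
)
print(results["results"]["mgsm_chat"])
```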

### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
24 changes: 24 additions & 0 deletions lm_eval/tasks/llama3/base/arc_challenge.yaml
@@ -0,0 +1,24 @@
tag:
  - llama
task: llama_arc_challenge
dataset_path: allenai/ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
fewshot_split: train
doc_to_text: "Question: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nAnswer:"
fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
doc_to_choice: "{{ choices.label|map('replace', '1', 'A')|map('replace', '2', 'B')|map('replace', '3', 'C')|map('replace', '4', 'D')|list if choices.label[0] in '1234' else choices.label }}"
num_fewshot: 25
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
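The `doc_to_target` and `doc_to_choice` expressions above exist because ARC answer keys are inconsistent: most items label choices `A`–`D`, but some use `1`–`4`. A standalone Python sketch of the same normalization, shown here only to unpack the Jinja:

```python
def normalize_answer_key(answer_key: str) -> str:
    # Mirrors the Jinja doc_to_target expression: numeric keys "1"-"4"
    # map onto "A"-"D"; letter keys pass through unchanged.
    if answer_key in "1234":
        return "ABCD"[int(answer_key) - 1]
    return answer_key

assert normalize_answer_key("3") == "C"
assert normalize_answer_key("B") == "B"
```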
15 changes: 15 additions & 0 deletions lm_eval/tasks/llama3/base/utils.py
@@ -0,0 +1,15 @@
import datasets


def process_arc_c_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Flatten a meta-llama evals dump into explicit text/choice/target fields."""
    COLUMNS = dataset.column_names

    def map_(doc):
        # Drop the final two characters of the stored prompt, then re-strip
        # whitespace, leaving only the question and its lettered choices.
        doc["doc_to_text"] = doc["input_final_prompts"][0].strip()[:-2].strip()
        # Choice completions are stored as "Answer: <choice>"; keep the choice.
        doc["doc_to_choice"] = [
            x.replace("Answer:", "").strip() for x in doc["output_choice_completions"]
        ]
        # The gold response likewise ends in the answer letter.
        doc["doc_to_target"] = doc["input_correct_responses"][0].strip()[-1]
        return doc

    # Drop all original columns so only the three derived fields remain.
    return dataset.map(map_, remove_columns=COLUMNS)
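No YAML in this diff references this helper yet; presumably it is meant to be attached to a meta-llama evals variant of the task via `process_docs: !function utils.process_arc_c_docs`, mirroring the `!function` hook used in `mgsm_chat.yaml`. A toy illustration of the transformation, with invented field values:

```python
import datasets

# Toy rows carrying the three fields the helper reads; values are invented.
toy = datasets.Dataset.from_dict({
    "input_final_prompts": [["Question: 2+2?\nA. 3\nB. 4\nAnswer: B."]],
    "output_choice_completions": [["Answer: 3", "Answer: 4"]],
    "input_correct_responses": [["Answer: B"]],
})
processed = process_arc_c_docs(toy)
print(processed[0]["doc_to_choice"])  # ['3', '4']
print(processed[0]["doc_to_target"])  # 'B'
```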
45 changes: 45 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm_chat.yaml
@@ -0,0 +1,45 @@
tag: llama3
task: mgsm_chat
dataset_path: meta-llama/Llama-3.2-3B-Instruct-evals
dataset_name: Llama-3.2-3B-Instruct-evals__mgsm__details
output_type: generate_until
test_split: latest
doc_to_text: "{{
  input_final_prompts
  |first
  |replace('<|eot_id|><|start_header_id|>assistant<|end_header_id|>', '')
  |replace('<|start_header_id|>', '')
  |replace('<|end_header_id|>', '')
  |replace('<|eot_id|>', '')
  |replace('^user', '')
  |trim
  }}"
doc_to_target: "input_correct_responses"
process_results: !function utils.process_results_mgsm
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0.0
  max_gen_toks: 2048
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: "Answer: (\\-?[0-9\\.\\,]+)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: regex
        group_select: -1
        regex_pattern: "Answer: (-?[$0-9.,]{2,})|(-?[0-9]+)"
      - function: take_first
      - function: remove_whitespace
      - function: take_first
metadata:
  version: 0.0
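Both filter pipelines pull the final numeric answer out of the model's reasoning before `exact_match` is computed: `strict-match` requires the exact `Answer: <number>` form, while `flexible-extract` uses `group_select: -1` to take the last number-like match. A small sketch of the strict pattern applied to an invented completion:

```python
import re

# Same pattern as the "strict-match" filter above.
STRICT = re.compile(r"Answer: (\-?[0-9\.\,]+)")

completion = "6 * 2 = 12, and 12 - 1 = 11. Answer: 11"  # invented output
match = STRICT.search(completion)
print(match.group(1) if match else "[no match]")  # -> 11
```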
15 changes: 15 additions & 0 deletions lm_eval/tasks/llama3/instruct/utils.py
@@ -0,0 +1,15 @@
from typing import List

from lm_eval.api.metrics import exact_match_fn


def process_results_mgsm(doc, prediction):
    """Score 1 if the filtered prediction exactly matches any accepted gold answer."""
    gold: List = doc["input_correct_responses"]
    # `prediction` arrives as a (length-1) list of filtered generations; tile
    # it to len(gold) so exact_match_fn compares it against every reference.
    return {
        "exact_match": int(
            exact_match_fn(
                predictions=prediction * len(gold), references=gold, ignore_case=True
            )["exact_match"]
            > 0
        )
    }
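A quick worked example (document contents invented) of how the hook scores one filtered generation against multiple accepted references:

```python
# Two accepted spellings of the same gold answer (invented).
doc = {"input_correct_responses": ["11", "eleven"]}

# The harness passes the filtered generations as a list.
print(process_results_mgsm(doc, ["11"]))  # {'exact_match': 1}
print(process_results_mgsm(doc, ["12"]))  # {'exact_match': 0}
```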