Skip to content

Commit

Permalink
fix datasets loading
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee committed Sep 12, 2024
1 parent e55e6af commit 3e71593
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions moe_peft/tasks/qa_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def loading_data(
self, is_train: bool = True, path: Optional[str] = None
) -> List[InputData]:
data = hf_datasets.load_dataset(
"allenai/ai2_arc" if path is None else path, self.subject_
"allenai/ai2_arc" if path is None else path,
self.subject_,
trust_remote_code=True,
)["train" if is_train else "test"]
logging.info(f"Preparing data for {self.subject_}")
ret: List[InputData] = []
Expand Down Expand Up @@ -63,9 +65,10 @@ def __init__(self) -> None:
def loading_data(
self, is_train: bool = True, path: Optional[str] = None
) -> List[InputData]:
data = hf_datasets.load_dataset("google/boolq" if path is None else path)[
"train" if is_train else "validation"
]
data = hf_datasets.load_dataset(
"google/boolq" if path is None else path,
trust_remote_code=True,
)["train" if is_train else "validation"]
logging.info("Preparing data for BoolQ")
ret: List[InputData] = []
for data_point in data:
Expand All @@ -92,7 +95,9 @@ def loading_data(
self, is_train: bool = True, path: Optional[str] = None
) -> List[InputData]:
data = hf_datasets.load_dataset(
"allenai/openbookqa" if path is None else path, "main"
"allenai/openbookqa" if path is None else path,
"main",
trust_remote_code=True,
)["train" if is_train else "test"]
logging.info("Preparing data for OpenBookQA")
ret: List[InputData] = []
Expand Down Expand Up @@ -180,9 +185,10 @@ def __init__(self) -> None:
def loading_data(
self, is_train: bool = True, path: Optional[str] = None
) -> List[InputData]:
data = hf_datasets.load_dataset("Rowan/hellaswag" if path is None else path)[
"train" if is_train else "validation"
]
data = hf_datasets.load_dataset(
"Rowan/hellaswag" if path is None else path,
trust_remote_code=True,
)["train" if is_train else "validation"]
logging.info("Preparing data for HellaSwag")
ret: List[InputData] = []
for data_point in data:
Expand Down Expand Up @@ -241,9 +247,9 @@ def __init__(self) -> None:
def loading_data(
self, is_train: bool = True, path: Optional[str] = None
) -> List[InputData]:
data = hf_datasets.load_dataset("tau/commonsense_qa" if path is None else path)[
"train" if is_train else "validation"
]
data = hf_datasets.load_dataset(
"tau/commonsense_qa" if path is None else path, trust_remote_code=True
)["train" if is_train else "validation"]
logging.info("Preparing data for CommonSenseQA")
ret: List[InputData] = []
for data_point in data:
Expand Down Expand Up @@ -275,6 +281,7 @@ def loading_data(
data = hf_datasets.load_dataset(
"qiaojin/PubMedQA" if path is None else path,
"pqa_artificial" if is_train else "pqa_labeled",
trust_remote_code=True,
)["train"]
logging.info("Preparing data for PubMedQA")
ret: List[InputData] = []
Expand Down

0 comments on commit 3e71593

Please sign in to comment.