diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py
index ab68781939..50aedf6a96 100644
--- a/lm_eval/__main__.py
+++ b/lm_eval/__main__.py
@@ -128,6 +128,15 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Limit the number of examples per task. "
         "If <1, limit is a percentage of the total number of examples.",
     )
+    parser.add_argument(
+        "--examples",
+        "-E",
+        default=None,
+        type=str,
+        metavar="/path/to/json",
+        help="Path to a JSON file mapping task names to the document indices to "
+        'evaluate, e.g. {"mmlu_anatomy": [0, 1], "mmlu_astronomy": [1, 2, 3]}.',
+    )
     parser.add_argument(
         "--use_cache",
         "-c",
@@ -309,10 +318,19 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         )
 
+    limit = args.limit
+    examples = None
     if args.limit:
         eval_logger.warning(
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
+    if args.examples:
+        assert (
+            args.limit is None
+        ), "--examples cannot be combined with --limit."
+        limit = None
+        with open(args.examples, "r") as json_file:
+            examples = json.load(json_file)
 
     if args.tasks is None:
         eval_logger.error("Need to specify task to evaluate.")
@@ -388,7 +406,8 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         max_batch_size=args.max_batch_size,
         device=args.device,
         use_cache=args.use_cache,
-        limit=args.limit,
+        limit=limit,
+        examples=examples,
         check_integrity=args.check_integrity,
         write_out=args.write_out,
         log_samples=args.log_samples,
@@ -445,7 +464,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         evaluation_tracker.recreate_metadata_card()
 
     print(
-        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
+        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {limit}, num_fewshot: {args.num_fewshot}, "
         f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
     print(make_table(results))
diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py
index 532e9e7ae6..58d78464fc 100644
--- a/lm_eval/api/task.py
+++ b/lm_eval/api/task.py
@@ -373,6 +373,7 @@ def build_all_requests(
         self,
         *,
         limit: Union[int, None] = None,
+        examples: Optional[List[int]] = None,
         rank: int = 0,
         world_size: int = 1,
         cache_requests: bool = False,
@@ -425,7 +426,9 @@ def build_all_requests(
                 limit = None
 
         doc_id_docs = list(
-            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
+            self.doc_iterator(
+                rank=rank, limit=limit, examples=examples, world_size=world_size
+            )
         )
 
         num_docs = len(doc_id_docs)
@@ -676,15 +679,38 @@ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
             )
 
     def doc_iterator(
-        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
+        self,
+        *,
+        rank: int = 0,
+        limit: Union[int, None] = None,
+        examples: Optional[List[int]] = None,
+        world_size: int = 1,
     ) -> Iterator[Tuple[int, Any]]:
-        limit = int(limit) if limit else None
-        doc_iterator = utils.create_iterator(
-            enumerate(self.eval_docs),
-            rank=int(rank),
-            limit=limit,
-            world_size=int(world_size),
-        )
+        if examples:
+            n = self.eval_docs.to_pandas().shape[0]
+            assert all(
+                [0 <= e < n for e in examples]
+            ), f"Elements of --examples must be in the interval [0, k-1], where k is the total number of documents; here k={n}."
+            doc_iterator = utils.create_iterator(
+                enumerate(
+                    datasets.Dataset.from_pandas(
+                        self.eval_docs.to_pandas()
+                        .iloc[examples, :]
+                        .reset_index(drop=True)
+                    )
+                ),
+                rank=int(rank),
+                limit=None,  # limit is irrelevant here since documents are selected explicitly
+                world_size=int(world_size),
+            )
+        else:
+            limit = int(limit) if limit else None
+            doc_iterator = utils.create_iterator(
+                enumerate(self.eval_docs),
+                rank=int(rank),
+                limit=limit,
+                world_size=int(world_size),
+            )
         return doc_iterator
 
 
diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py
index d0c1a19a65..fdbce7c5fc 100644
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -58,6 +58,7 @@ def simple_evaluate(
     rewrite_requests_cache: bool = False,
     delete_requests_cache: bool = False,
     limit: Optional[Union[int, float]] = None,
+    examples: Optional[dict] = None,
     bootstrap_iters: int = 100000,
     check_integrity: bool = False,
     write_out: bool = False,
@@ -102,6 +103,8 @@ def simple_evaluate(
         Deletes all of the request cache if set to `True`. `None` if not desired.
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
+    :param examples: dictionary, optional
+        Dictionary mapping each task name to the list of document indices to evaluate, e.g., {'mmlu_astronomy': [0, 3, 6], 'mmlu_anatomy': [1, 4, 7, 10]}.
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
     :param check_integrity: bool
@@ -139,6 +142,9 @@ def simple_evaluate(
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
     start_date = time.time()
 
+    if limit is not None and examples is not None:
+        raise ValueError("'limit' and 'examples' cannot both be set; use at most one.")
+
     if delete_requests_cache:
         eval_logger.info("Deleting requests cache...")
         delete_cache()
@@ -302,6 +308,7 @@ def _adjust_config(task_dict):
             lm=lm,
             task_dict=task_dict,
             limit=limit,
+            examples=examples,
             cache_requests=cache_requests,
             rewrite_requests_cache=rewrite_requests_cache,
             bootstrap_iters=bootstrap_iters,
@@ -361,6 +368,7 @@ def evaluate(
     lm: "LM",
     task_dict,
     limit: Optional[int] = None,
+    examples: Optional[dict] = None,
     cache_requests: bool = False,
     rewrite_requests_cache: bool = False,
     bootstrap_iters: Optional[int] = 100000,
@@ -379,6 +387,8 @@ def evaluate(
         Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
+    :param examples: dictionary, optional
+        Dictionary mapping each task name to the list of document indices to evaluate, e.g., {'mmlu_astronomy': [0, 3, 6], 'mmlu_anatomy': [1, 4, 7, 10]}.
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
     :param write_out: bool
@@ -400,6 +410,9 @@ def evaluate(
 
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
 
+    if limit is not None and examples is not None:
+        raise ValueError("'limit' and 'examples' cannot both be set; use at most one.")
+
     # tracks all Instances/requests a model must generate output on.
     requests = defaultdict(list)
     # stores the amount to pad out reqs per req. type so that
@@ -443,6 +456,7 @@ def evaluate(
         limits.append(limit)
         task.build_all_requests(
             limit=limit,
+            examples=examples[task_output.task_name] if examples is not None else None,
             rank=lm.rank,
             world_size=lm.world_size,
             cache_requests=cache_requests,
@@ -527,9 +541,17 @@ def evaluate(
         # iterate over different filters used
         for filter_key in task.instances[0].filtered_resps.keys():
             doc_iterator = task.doc_iterator(
-                rank=RANK, limit=limit, world_size=WORLD_SIZE
+                rank=RANK,
+                limit=limit,
+                examples=examples[task_output.task_name] if examples is not None else None,
+                world_size=WORLD_SIZE,
             )
             for doc_id, doc in doc_iterator:
+                # map doc_id back to the original document index when --examples is used
+                if examples:
+                    doc_id_true = examples[task_output.task_name][doc_id]
+                else:
+                    doc_id_true = doc_id
                 requests = instances_by_doc_id[doc_id]
                 metrics = task.process_results(
                     doc, [req.filtered_resps[filter_key] for req in requests]
@@ -537,7 +559,7 @@ def evaluate(
                 if log_samples:
                     target = task.doc_to_target(doc)
                     example = {
-                        "doc_id": doc_id,
+                        "doc_id": doc_id_true,
                         "doc": doc,
                         "target": target,
                         "arguments": [req.args for req in requests],
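
Usage sketch (not part of the patch): assuming the changes above are applied, the new per-task document selection can be driven either through the CLI flag --examples or through the Python API's new examples argument. The model name, file path, and task names below are illustrative only.

    import json

    import lm_eval

    # Map each task name to the document indices to evaluate; this is the same
    # structure the CLI loads from the JSON file passed via --examples.
    examples = {"mmlu_anatomy": [0, 1], "mmlu_astronomy": [1, 2, 3]}

    # The written file can then be passed to the CLI as: --examples examples.json
    with open("examples.json", "w") as f:
        json.dump(examples, f)

    # Python API equivalent; `examples` is mutually exclusive with `limit`.
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args="pretrained=EleutherAI/pythia-160m",  # illustrative model
        tasks=["mmlu_anatomy", "mmlu_astronomy"],
        examples=examples,
    )

Each task's selection is forwarded to Task.build_all_requests and Task.doc_iterator, and logged samples keep the original document indices in their "doc_id" field.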