The `TemplateAPI` class is a versatile superclass designed to facilitate the integration of various API-based language models into the lm-evaluation-harness framework. This guide explains how to use and extend the `TemplateAPI` class to implement your own API models. If your API implements the OpenAI API, you can use the `local-completions` or `local-chat-completions` model types, which also serve as examples of how to subclass this template effectively.
The `TemplateAPI` class provides a template for creating API-based model implementations. It handles common functionality such as:
- Tokenization (optional)
- Batch processing
- Caching
- Retrying failed requests
- Parsing API responses
To use this class, you typically need to subclass it and implement specific methods for your API.
When subclassing `TemplateAPI`, you need to implement the following methods:
- `_create_payload`: Creates the JSON payload for API requests.
- `parse_logprobs`: Parses log probabilities from API responses.
- `parse_generations`: Parses generated text from API responses.
- `header`: Returns the headers for the API request.
You may also need to override other methods or properties depending on your API's specific requirements.
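As a rough illustration of what three of these pieces do, the sketch below writes them as standalone functions for a hypothetical completions-style API. The payload fields (`prompt`, `max_tokens`, `stop`, `echo`, `logprobs`), the `until` and `max_gen_toks` keys, the response shape (`choices[i].text`), and bearer-token authentication are all assumptions modeled on OpenAI-style endpoints, not requirements of `TemplateAPI`. In a real subclass these would be methods or properties on the class.

```python
from typing import Dict, List, Optional, Union


def create_payload(
    messages: Union[List[str], str],
    model: str,
    generate: bool = False,
    gen_kwargs: Optional[dict] = None,
) -> dict:
    """Build a request body for a hypothetical OpenAI-style completions API."""
    if generate:
        gen_kwargs = dict(gen_kwargs or {})
        # "until" and "max_gen_toks" are assumed key names for this sketch.
        stop = gen_kwargs.pop("until", None)
        max_tokens = gen_kwargs.pop("max_gen_toks", 256)
        return {
            "model": model,
            "prompt": messages,
            "max_tokens": max_tokens,
            "stop": stop,
            **gen_kwargs,
        }
    # Loglikelihood request: ask the API to echo the prompt with logprobs.
    return {
        "model": model,
        "prompt": messages,
        "max_tokens": 1,
        "logprobs": 1,
        "echo": True,
    }


def parse_generations(outputs: Union[Dict, List[Dict]]) -> List[str]:
    """Collect generated text from one or more responses, ordered by index."""
    if not isinstance(outputs, list):
        outputs = [outputs]
    texts: List[str] = []
    for out in outputs:
        for choice in sorted(out["choices"], key=lambda c: c["index"]):
            texts.append(choice["text"])
    return texts


def header(api_key: str) -> dict:
    """Return request headers; bearer-token auth is an assumption here."""
    return {"Authorization": f"Bearer {api_key}"}
```

The `generate` flag is what separates the two request types: generation requests carry sampling arguments, while loglikelihood requests only need the prompt scored.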
Note
Currently, loglikelihood and MCQ-based tasks (such as MMLU) are supported only for completion endpoints, not for chat-completion endpoints (those that expect a list of dicts). Completion APIs that support instruct-tuned models can be evaluated with the `--apply_chat_template` option, which evaluates the model using a chat template format while still providing access to the model logits needed for loglikelihood-based tasks.
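To make the distinction concrete, here is a small sketch of the two input shapes. A chat-completion endpoint receives a list of role/content dicts, while a completion endpoint receives a single string, so a chat-formatted conversation has to be rendered to text before it is sent. The template markers in `render_chat_as_prompt` are made up for illustration; real models define their own chat templates.

```python
from typing import Dict, List

# Shape expected by chat-completion endpoints: a list of dicts.
chat_messages: List[Dict[str, str]] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]


def render_chat_as_prompt(messages: List[Dict[str, str]]) -> str:
    """Flatten chat messages into one string using a toy, made-up template."""
    parts = [f"<|{m['role']}|>\n{m['content']}\n" for m in messages]
    # Leave the assistant turn open so the model continues from here.
    parts.append("<|assistant|>\n")
    return "".join(parts)


# Shape expected by completion endpoints: a plain string prompt.
prompt = render_chat_as_prompt(chat_messages)
```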
When initializing a `TemplateAPI` instance or a subclass, you can provide several arguments to customize its behavior. Here's a detailed explanation of some important arguments:
- `model` or `pretrained` (str):
  - The name or identifier of the model to use.
  - `model` takes precedence over `pretrained` when both are provided.
- `base_url` (str):
  - The base URL for the API endpoint.
- `tokenizer` (str, optional):
  - The name or path of the tokenizer to use.
  - If not provided, it defaults to using the same tokenizer name as the model.
- `num_concurrent` (int):
  - Number of concurrent requests to make to the API.
  - Useful for APIs that support parallel processing.
  - Default is 1 (sequential processing).
- `timeout` (int, optional):
  - Timeout for API requests in seconds.
  - Default is 30.
- `tokenized_requests` (bool):
  - Determines whether the input is pre-tokenized. Defaults to `True`.
  - Requests can be sent in either tokenized form (`list[list[int]]`) or as text (`list[str]`, or `str` for batch_size=1).
  - For loglikelihood-based tasks, prompts require tokenization to calculate the context length. If `False`, prompts are decoded back to text before being sent to the API.
  - Not as important for `generate_until` tasks.
  - Ignored for chat-formatted inputs (`list[dict...]`) or if `tokenizer_backend` is None.
- `tokenizer_backend` (str, optional):
  - Required for loglikelihood-based or MCQ tasks.
  - Specifies the tokenizer library to use. Options are "tiktoken", "huggingface", or None.
  - Default is "huggingface".
- `max_length` (int, optional):
  - Maximum length of input + output.
  - Default is 2048.
- `max_retries` (int, optional):
  - Maximum number of retries for failed API requests.
  - Default is 3.
- `max_gen_toks` (int, optional):
  - Maximum number of tokens to generate in completion tasks.
  - Default is 256, unless set in the task YAML.
- `batch_size` (int or str, optional):
  - Number of requests to batch together (if the API supports batching).
  - Can be an integer or "auto" (which defaults to 1 for API models).
  - Default is 1.
- `seed` (int, optional):
  - Random seed for reproducibility.
  - Default is 1234.
- `add_bos_token` (bool, optional):
  - Whether to add the beginning-of-sequence token to inputs (when tokenizing).
  - Default is False.
- `custom_prefix_token_id` (int, optional):
  - Custom token ID to use as a prefix for inputs.
  - If not provided, uses the model's default BOS or EOS token (if `add_bos_token` is True).
- `verify_certificate` (bool, optional):
  - Whether to validate the certificate of the API endpoint (if HTTPS).
  - Default is True.
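The interaction between `tokenized_requests`, `max_length`, and the context length can be sketched without calling any API. The example below uses a toy whitespace tokenizer, an assumption made only to keep the sketch self-contained (a real run would use the `tiktoken` or `huggingface` backend). `prepare_request` is a hypothetical helper, not a harness function, showing how a loglikelihood request might be prepared in either tokenized or text form.

```python
from typing import Dict, List, Tuple, Union


class ToyTokenizer:
    """Whitespace tokenizer standing in for a real tokenizer backend."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}
        self.inverse: Dict[int, str] = {}

    def encode(self, text: str) -> List[int]:
        ids = []
        for word in text.split():
            if word not in self.vocab:
                idx = len(self.vocab)
                self.vocab[word] = idx
                self.inverse[idx] = word
            ids.append(self.vocab[word])
        return ids

    def decode(self, ids: List[int]) -> str:
        return " ".join(self.inverse[i] for i in ids)


def prepare_request(
    tok: ToyTokenizer,
    context: str,
    continuation: str,
    max_length: int = 2048,
    tokenized_requests: bool = True,
) -> Tuple[Union[List[int], str], int]:
    """Return the request input and the context length in tokens."""
    ctx_ids = tok.encode(context)
    cont_ids = tok.encode(continuation)
    # Keep the rightmost max_length tokens so the continuation survives.
    full = (ctx_ids + cont_ids)[-max_length:]
    ctxlen = len(full) - len(cont_ids)
    if tokenized_requests:
        return full, ctxlen
    # Text mode: decode the (possibly truncated) tokens back to a string.
    return tok.decode(full), ctxlen
```

Either way the context length is computed from tokens, which is why a tokenizer backend is needed for loglikelihood-based tasks even when the request itself is sent as text.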
Example usage:
```python
from lm_eval.models.api_models import TemplateAPI


class MyAPIModel(TemplateAPI):
    def __init__(self, **kwargs):
        super().__init__(
            model="my-model",
            base_url="https://api.mymodel.com/v1/completions",
            tokenizer_backend="huggingface",
            num_concurrent=5,
            max_retries=5,
            batch_size=10,
            **kwargs,
        )

    # Implement other required methods...
```
When subclassing `TemplateAPI`, you can override these arguments in your `__init__` method to set default values specific to your API. You can also add additional (potentially user-specified) arguments as needed for your specific implementation.
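A minimal sketch of adding a user-specified argument on top of overridden defaults follows. `FakeTemplateAPI` is a stand-in defined here only so the example runs without the harness installed, and `organization` is a hypothetical extra argument. In real code the base class would be `TemplateAPI` itself.

```python
from typing import Optional


class FakeTemplateAPI:
    """Stand-in for TemplateAPI that just records a few arguments."""

    def __init__(self, model=None, base_url=None, num_concurrent=1,
                 max_retries=3, **kwargs):
        self.model = model
        self.base_url = base_url
        self.num_concurrent = num_concurrent
        self.max_retries = max_retries


class MyAPIModel(FakeTemplateAPI):
    def __init__(
        self,
        base_url: str = "https://api.mymodel.com/v1/completions",
        num_concurrent: int = 5,             # API-specific default
        organization: Optional[str] = None,  # hypothetical extra argument
        **kwargs,
    ):
        # The extra argument is stored on the subclass, not passed upward.
        self.organization = organization
        super().__init__(
            base_url=base_url, num_concurrent=num_concurrent, **kwargs
        )
```

Exposing the defaults as keyword arguments, rather than hard-coding them in the `super().__init__` call, leaves users free to override them when constructing the model.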
The `OpenAICompletionsAPI` and `OpenAIChatCompletion` classes demonstrate how to implement API models using the `TemplateAPI` class. Here's a breakdown of the key components:
```python
import os
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI


# Abridged excerpt: method bodies are elided.
@register_model("openai-completions")
class OpenAICompletionsAPI(LocalCompletionsAPI):
    def __init__(
        self,
        base_url="https://api.openai.com/v1/completions",
        tokenizer_backend="tiktoken",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    @cached_property
    def api_key(self):
        # Read the key from the environment rather than hard-coding it.
        key = os.environ.get("OPENAI_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the OPENAI_API_KEY environment variable."
            )
        return key

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> dict:
        if generate:
            ...  # (implementation for generation)
        else:
            ...  # (implementation for log likelihood)

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        ...  # (implementation)

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        ...  # (implementation)
```
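Since the body of `parse_logprobs` is elided above, here is one way such a parser might look. This is a sketch, not the harness's exact implementation. It assumes an OpenAI-style completions response in which each choice carries `logprobs.token_logprobs` and `logprobs.top_logprobs` for the echoed prompt plus one extra generated token, and it returns the `(log-likelihood, is_greedy)` pairs implied by the return type `List[Tuple[float, bool]]`.

```python
from typing import Dict, List, Tuple, Union


def parse_logprobs(
    outputs: Union[Dict, List[Dict]],
    ctxlens: List[int],
) -> List[Tuple[float, bool]]:
    """Sum continuation logprobs and check greedy decoding per choice."""
    if not isinstance(outputs, list):
        outputs = [outputs]
    ctx_iter = iter(ctxlens)  # consumed across all responses, in order
    results: List[Tuple[float, bool]] = []
    for out in outputs:
        choices = sorted(out["choices"], key=lambda c: c["index"])
        for choice, ctxlen in zip(choices, ctx_iter):
            lp = choice["logprobs"]
            # Skip the context tokens; drop the one extra generated token.
            token_lps = lp["token_logprobs"][ctxlen:-1]
            top_lps = lp["top_logprobs"][ctxlen:-1]
            # Greedy if every continuation token was the top candidate.
            is_greedy = all(
                tok_lp == max(top.values())
                for tok_lp, top in zip(token_lps, top_lps)
            )
            results.append((sum(token_lps), is_greedy))
    return results
```

The context length is what lets the parser ignore the prompt tokens and score only the continuation, which is the quantity loglikelihood-based tasks compare across answer choices.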
The requests are initiated in the `model_call` or the `amodel_call` methods.
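The retry behavior that `TemplateAPI` provides can be pictured as a loop around the request call. The sketch below is not the harness's `model_call`. Here `send` is a hypothetical callable standing in for the HTTP request, and the exponential backoff schedule and the choice to return `None` on failure are arbitrary assumptions.

```python
import time
from typing import Callable, Optional


def call_with_retries(
    send: Callable[[dict], dict],
    payload: dict,
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Optional[dict]:
    """Call send(payload), retrying on exceptions up to max_retries times."""
    for attempt in range(max_retries + 1):
        try:
            return send(payload)
        except Exception:
            if attempt == max_retries:
                # Out of retries: give up and signal failure to the caller.
                return None
            # Exponential backoff between attempts.
            time.sleep(base_delay * (2 ** attempt))
    return None
```

With `max_retries=3`, the request is attempted at most four times: the initial call plus three retries.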
To implement your own API model:
- Subclass `TemplateAPI` or one of its subclasses (e.g., `LocalCompletionsAPI`).
- Override the `__init__` method if you need to set specific parameters.
- Implement the `_create_payload` and `header` methods to create the appropriate payload for your API.
- Implement the `parse_logprobs` and `parse_generations` methods to parse your API's responses.
- Override the `api_key` property if your API requires authentication.
- Override any other methods as necessary to match your API's behavior.
- Use the `@register_model` decorator to register your model with the framework (and import it in `lm_eval/models/__init__.py`!).
- Use environment variables for sensitive information like API keys.
- Properly handle batching and concurrent requests if supported by your API.
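The last two points can be sketched with the standard library alone. The helpers below are hypothetical and independent of the harness: `chunk` groups requests into batches, `run_batches` caps the number of in-flight requests with a semaphore (mirroring what `num_concurrent` is meant to control), and `get_api_key` reads a key from an environment variable whose name, `MY_API_KEY`, is made up for this example.

```python
import asyncio
import os
from typing import Awaitable, Callable, List, Sequence


def chunk(items: Sequence[str], batch_size: int) -> List[List[str]]:
    """Split requests into consecutive batches of at most batch_size."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def get_api_key(var: str = "MY_API_KEY") -> str:
    """Read the key from the environment; the variable name is hypothetical."""
    key = os.environ.get(var)
    if key is None:
        raise ValueError(f"API key not found. Please set {var}.")
    return key


async def run_batches(
    batches: List[List[str]],
    send: Callable[[List[str]], Awaitable[List[str]]],
    num_concurrent: int = 1,
) -> List[str]:
    """Send batches with at most num_concurrent requests in flight."""
    sem = asyncio.Semaphore(num_concurrent)

    async def guarded(batch: List[str]) -> List[str]:
        async with sem:
            return await send(batch)

    # gather preserves input order, so results line up with the requests.
    per_batch = await asyncio.gather(*(guarded(b) for b in batches))
    return [item for batch in per_batch for item in batch]
```

Keeping results in request order matters because the evaluation harness pairs each response with the request that produced it.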