Adding OLMo #1827

Merged: 12 commits, Nov 13, 2024
1 change: 1 addition & 0 deletions README.md
@@ -128,6 +128,7 @@ Every model is written from scratch to maximize performance and remove layers of
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
71 changes: 69 additions & 2 deletions litgpt/config.py
@@ -149,12 +149,21 @@ def mlp_class(self) -> Type:
@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config serializable
from functools import partial

if self.norm_class_name == "RMSNorm":
from litgpt.model import RMSNorm

return partial(RMSNorm, add_unit_offset="Gemma" in self.name)

if self.norm_class_name == "LayerNorm" and "OLMo" in self.name:
# this makes it equivalent to `torch.nn.functional.layer_norm`, which OLMo uses;
# see the Table 5 caption in the OLMo paper: https://aclanthology.org/2024.acl-long.841
return partial(torch.nn.LayerNorm, elementwise_affine=False)

return getattr(torch.nn, self.norm_class_name)
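
For intuition, here is a small standalone sketch (not part of the diff) of the equivalence the comment above points to; it only assumes `torch`:

```python
import torch

x = torch.randn(2, 4, 32)

# Non-parametric LayerNorm, as configured for OLMo above:
# no learnable weight or bias, just the normalization itself.
ln = torch.nn.LayerNorm(32, elementwise_affine=False)

# The functional form used by the reference OLMo implementation.
y = torch.nn.functional.layer_norm(x, normalized_shape=(32,))

torch.testing.assert_close(ln(x), y)
```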


@@ -722,6 +731,64 @@ def norm_class(self) -> Type:
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
)

#################
# Allen AI OLMo
#################
olmo = [
# https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json
dict(
name="OLMo-1B-hf",
hf_config=dict(org="allenai", name="OLMo-1B-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_embd=2048,
n_layer=16,
n_head=16,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
),
# https://huggingface.co/allenai/OLMo-7B-hf/blob/main/config.json
dict(
name="OLMo-7B-hf",
hf_config=dict(org="allenai", name="OLMo-7B-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/allenai/OLMo-7B-Instruct-hf/blob/main/config.json
dict(
name="OLMo-7B-Instruct-hf",
hf_config=dict(org="allenai", name="OLMo-7B-Instruct-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
]

configs.extend(olmo)
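
As a quick sanity check (a sketch, not part of the PR), the new entries can be loaded through the existing `Config.from_name` helper:

```python
from litgpt.config import Config

config = Config.from_name("OLMo-1B-hf")

# OLMo pairs a SwiGLU-style MLP (LLaMAMLP) with a non-parametric LayerNorm.
assert config.mlp_class_name == "LLaMAMLP"
assert config.norm_class_name == "LayerNorm"
assert config.padded_vocab_size == 50304
```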

###############
# Google Gemma
###############
8 changes: 8 additions & 0 deletions litgpt/prompts.py
@@ -276,6 +276,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:



class OLMo(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"


# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
# Dataset-specific prompt styles
@@ -298,6 +303,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"tinyllama": TinyLlama,
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
}


@@ -334,6 +340,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
return Gemma()
if re.search(r"OLMo.*-hf", model_name):
return OLMo()
return Default()
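
A hedged usage sketch of the two additions above (the example prompt is arbitrary):

```python
from litgpt.prompts import model_name_to_prompt_style

style = model_name_to_prompt_style("OLMo-7B-Instruct-hf")
print(style.apply("What food do llamas eat?"))
# <|endoftext|><|user|>
# What food do llamas eat?
# <|assistant|>
```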


5 changes: 4 additions & 1 deletion litgpt/tokenizer.py
@@ -130,7 +130,10 @@ def encode(

if eos and (not tokens or tokens[-1] != self.eos_id):
tokens = tokens + [self.eos_id]

# if the processor misbehaves and adds the `eos` token no matter what
elif tokens and tokens[-1] == self.eos_id:
tokens = tokens[:-1]

if max_length > 0:
tokens = tokens[:max_length]
return torch.tensor(tokens, dtype=torch.int, device=device)
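
The following self-contained sketch (with made-up token ids; 7 stands in for `eos_id`) mirrors the branch added above for the case it targets, a processor that appends EOS even though `eos=False` was requested:

```python
eos_id = 7
eos = False
tokens = [3, 1, 4, eos_id]  # processor appended EOS although eos=False was requested

if eos and (not tokens or tokens[-1] != eos_id):
    tokens = tokens + [eos_id]
# the processor misbehaved and added the eos token anyway
elif tokens and tokens[-1] == eos_id:
    tokens = tokens[:-1]

assert tokens == [3, 1, 4]
```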
43 changes: 43 additions & 0 deletions tests/test_convert_lit_checkpoint.py
@@ -14,6 +14,7 @@
from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM

from litgpt import GPT, Config
from litgpt.scripts.convert_lit_checkpoint import (
@@ -192,6 +193,48 @@ def test_against_mixtral():
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
def test_against_olmo(model_name):
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
intermediate_size=86,
)
T = 5
theirs_config = OlmoConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_position_embeddings=T,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
tie_word_embeddings=(model_name == "OLMo-1B-hf"),
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

ours_model = GPT(ours_config)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=(model_name == "OLMo-1B-hf"))
theirs_model = OlmoForCausalLM(theirs_config)
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
def test_against_original_open_llama_3b():
58 changes: 58 additions & 0 deletions tests/test_model.py
@@ -27,6 +27,7 @@
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM

import litgpt.config as config_module
from litgpt.model import batched_index_copy_
@@ -551,6 +552,63 @@ def test_against_hf_mixtral():
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference upcasts the attention softmax to fp32; additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_olmo(model_name, device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
intermediate_size=86,
)
T = 5
theirs_config = OlmoConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_position_embeddings=T,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
tie_word_embeddings=(model_name == "OLMo-1B-hf"),
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = OlmoForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
@pytest.mark.parametrize(
4 changes: 4 additions & 0 deletions tutorials/download_model_weights.md
@@ -27,6 +27,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
@@ -54,6 +55,9 @@ litgpt download list
The output is shown below:

```
allenai/OLMo-1B-hf
allenai/OLMo-7B-hf
allenai/OLMo-7B-Instruct-hf
codellama/CodeLlama-13b-hf
codellama/CodeLlama-13b-Instruct-hf
codellama/CodeLlama-13b-Python-hf
...
```
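
With these entries in place, the new checkpoints should be downloadable in the usual way, e.g. for the 1B variant (assuming the standard `litgpt download <repo_id>` form used throughout this tutorial):

```bash
litgpt download allenai/OLMo-1B-hf
```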