Commit: Fix chat template tests

lewtun committed Aug 19, 2024
1 parent fbe98e3 commit 338fbb1
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/alignment/data.py
```diff
@@ -32,7 +32,7 @@ def maybe_insert_system_message(messages, tokenizer):
     # chat template can be one of two attributes, we check in order
     chat_template = tokenizer.chat_template
     if chat_template is None:
-        chat_template = tokenizer.default_chat_template
+        chat_template = tokenizer.get_chat_template()
 
     # confirm the jinja template refers to a system message before inserting
     if "system" in chat_template or "<|im_start|>" in chat_template:
```
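The fix tracks a transformers API change: `default_chat_template` was deprecated and later removed, while `get_chat_template()` returns the template string that `apply_chat_template` would use. For orientation, a minimal sketch of the patched helper — the guard and insertion lines outside the hunk are assumptions inferred from the surrounding context:

```python
def maybe_insert_system_message(messages, tokenizer):
    # Assumed guard: leave conversations that already open with a system turn alone.
    if messages[0]["role"] == "system":
        return

    # chat template can be one of two attributes, we check in order
    chat_template = tokenizer.chat_template
    if chat_template is None:
        chat_template = tokenizer.get_chat_template()

    # confirm the jinja template refers to a system message before inserting
    if "system" in chat_template or "<|im_start|>" in chat_template:
        messages.insert(0, {"role": "system", "content": ""})
```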
20 changes: 10 additions & 10 deletions tests/test_data.py
```diff
@@ -122,21 +122,21 @@ def setUp(self):
         )
 
     def test_maybe_insert_system_message(self):
-        # does not accept system prompt. Use community checkpoint since it has no HF token requirement
-        mistral_tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
-        # accepts system prompt. use codellama since it has no HF token requirement
-        llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
+        # Chat template that does not accept system prompt. Use community checkpoint since it has no HF token requirement
+        tokenizer_sys_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
+        # Chat template that accepts system prompt
+        tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
         messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
         messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]
 
-        mistral_messages = deepcopy(messages_sys_excl)
-        llama_messages = deepcopy(messages_sys_excl)
-        maybe_insert_system_message(mistral_messages, mistral_tokenizer)
-        maybe_insert_system_message(llama_messages, llama_tokenizer)
+        messages_proc_excl = deepcopy(messages_sys_excl)
+        message_proc_incl = deepcopy(messages_sys_excl)
+        maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl)
+        maybe_insert_system_message(message_proc_incl, tokenizer_sys_incl)
 
         # output from mistral should not have a system message, output from llama should
-        self.assertEqual(mistral_messages, messages_sys_excl)
-        self.assertEqual(llama_messages, messages_sys_incl)
+        self.assertEqual(messages_proc_excl, messages_sys_excl)
+        self.assertEqual(message_proc_incl, messages_sys_incl)
 
     def test_sft(self):
         dataset = self.dataset.map(
```
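The renamed variables (`tokenizer_sys_excl` / `tokenizer_sys_incl`) spell out which fixture exercises which branch. The same two branches can be checked offline with a hand-written template and a hypothetical stub in place of the hub checkpoints — a sketch, assuming the helper behaves as in the patched version above:

```python
from copy import deepcopy

from alignment.data import maybe_insert_system_message


class StubTokenizer:
    """Hypothetical stand-in exposing only the attribute the helper reads."""

    def __init__(self, template):
        self.chat_template = template


messages = [{"role": "user", "content": "Tell me a joke."}]

# A ChatML-style template contains "<|im_start|>", so an empty system turn is prepended.
chatml = deepcopy(messages)
maybe_insert_system_message(chatml, StubTokenizer("<|im_start|>{{ messages[0]['content'] }}"))
assert chatml[0] == {"role": "system", "content": ""}

# A template with no system support leaves the conversation unchanged.
plain = deepcopy(messages)
maybe_insert_system_message(plain, StubTokenizer("{% for m in messages %}[INST] {{ m['content'] }} [/INST]{% endfor %}"))
assert plain == messages
```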
2 changes: 1 addition & 1 deletion tests/test_model_utils.py
```diff
@@ -75,7 +75,7 @@ def test_default_chat_template_no_overwrite(self):
         processed_tokenizer = get_tokenizer(model_args, DataArguments())
 
         assert getattr(processed_tokenizer, "chat_template") is None
-        self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template)
+        self.assertEqual(base_tokenizer.get_chat_template(), processed_tokenizer.get_chat_template())
 
     def test_chatml_chat_template(self):
         chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
```
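For reference, the ChatML template asserted in `test_chatml_chat_template` renders a conversation as alternating `<|im_start|>role ... <|im_end|>` blocks. A sketch using `jinja2` directly, so it runs without downloading a tokenizer:

```python
from jinja2 import Template

# The same ChatML template string asserted in the test above.
chat_template = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."},
]

print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Tell me a joke.<|im_end|>
# <|im_start|>assistant
```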
