Skip to content

Commit

Permalink
Merge pull request #161 from allenai/add-no-grad
Browse files Browse the repository at this point in the history
Add torch.no_grad, fix greedy_until bug
  • Loading branch information
OyvindTafjord authored Apr 19, 2024
2 parents cae37ad + 6f0389c commit c3eb82e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Added `torch.no_grad()` around model calls in `language_model.py`
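- Prevented crashes in `greedy_until` in `language_model.py` by using a more robust stop token: it now defaults to the tokenizer's EOS token, and stop phrases are tokenized without special tokens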

## [v1.0.0rc0](https://github.com/allenai/catwalk/releases/tag/v1.0.0rc0) - 2023-12-19

### Added
Expand Down
22 changes: 12 additions & 10 deletions catwalk/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,8 @@ def _run_loglikelihood_tokens(
for field_name, tensors in unpadded_batch.items()
}

batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
with torch.no_grad():
batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
z = zip(
batch_of_indices,
batch_logits,
Expand Down Expand Up @@ -642,8 +643,8 @@ def _run_greedy_until(
if isinstance(untils, str):
untils = [untils]
# if any of the stop phrases are single tokens we can use that for early termination
primary_until = None
for tokenized_until in tokenizer(untils)["input_ids"]:
primary_until = tokenizer.eos_token_id
for tokenized_until in tokenizer(untils, add_special_tokens=False)["input_ids"]:
if len(tokenized_until) == 1:
primary_until = tokenized_until[0]

Expand All @@ -652,13 +653,14 @@ def _run_greedy_until(
[tokenized_context[max_gen_toks - model_max_length :]]
).to(model.device)

full_text_tensor = model.generate(
context_tensor,
max_length=context_tensor.shape[1] + max_gen_toks,
eos_token_id=primary_until,
do_sample=False,
pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added
)
with torch.no_grad():
full_text_tensor = model.generate(
context_tensor,
max_length=context_tensor.shape[1] + max_gen_toks,
eos_token_id=primary_until,
do_sample=False,
pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added
)
continuation_tensor = full_text_tensor[0, context_tensor.shape[1] :]
continuation = tokenizer.decode(continuation_tensor.tolist())
raw_continuation = continuation
Expand Down

0 comments on commit c3eb82e

Please sign in to comment.