Optimize update_step_ctx on Ascend (#2804)
* opt update_ctx for ascend

* fix lint
jinminxi104 authored Nov 25, 2024
1 parent f13c0f9 commit b5b3179
Showing 1 changed file with 36 additions and 25 deletions: lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -71,31 +71,42 @@ def get_total_slots():
         max_q_seq_len = max(q_seqlens_list)
         max_kv_seq_len = max(kv_seqlens_list)
 
-        for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = q_seqlens_list[i]
-            kv_seq_len = kv_seqlens_list[i]
-
-            # collect kv start indices.
-            history_length = kv_seq_len - q_seq_len
-            total_slots = get_total_slots()
-            slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
-            slots = slot_tables[history_length:kv_seq_len]
-            kv_start_indices.append(slots)
-
-            # collect attention mask of paged_prefill attention stage.
-            if not (step_context.is_decoding or is_unpaged_prefill):
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(q_seq_len,
-                                   step_context.block_offsets.shape[1] *
-                                   block_size,
-                                   dtype=torch.bool,
-                                   device=step_context.block_offsets.device),
-                        diagonal=kv_seq_len - q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
-
-        kv_start_indices = torch.cat(kv_start_indices)
+        if step_context.is_decoding:
+            # collect kv_start_indices without using a for-loop,
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            block_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = step_context.block_offsets.gather(
+                1, block_num.view(-1, 1)).view(-1)
+            kv_start_indices = last_block * block_size + idx
+        else:
+            for i in range(step_context.q_start_loc.size(0)):
+                q_seq_len = q_seqlens_list[i]
+                kv_seq_len = kv_seqlens_list[i]
+
+                # collect kv start indices during the prefill phase.
+                history_length = kv_seq_len - q_seq_len
+                total_slots = get_total_slots()
+                slot_tables = total_slots[step_context.block_offsets[i]].view(
+                    -1)
+                slots = slot_tables[history_length:kv_seq_len]
+                kv_start_indices.append(slots)
+
+                # collect attention mask of paged_prefill attention stage.
+                if not is_unpaged_prefill:
+                    single_attention_mask = torch.logical_not(
+                        torch.tril(
+                            torch.ones(
+                                q_seq_len,
+                                step_context.block_offsets.shape[1] *
+                                block_size,
+                                dtype=torch.bool,
+                                device=step_context.block_offsets.device),
+                            diagonal=kv_seq_len - q_seq_len,
+                        ))
+                    attention_mask.append(single_attention_mask)
+
+            kv_start_indices = torch.cat(kv_start_indices)
 
         if step_context.is_decoding:
             # prepare some params of paged_decode attention stage.
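The decoding branch is the heart of the change: instead of looping over sequences in Python, the slot of the single token each sequence writes during decode is computed with a few tensor ops. Below is a minimal sketch of that arithmetic in plain CPU PyTorch; block_size, kv_seqlens and block_offsets mirror the names in the diff, but the values are made up, and the loop version is only a cross-check, not the actual old code path.

import torch

# Toy paged-KV layout: 3 sequences, block_size 4, block tables padded to width 3.
block_size = 4
kv_seqlens = torch.tensor([5, 4, 9])      # current kv length per sequence
block_offsets = torch.tensor([[2, 7, 0],  # logical -> physical block ids
                              [5, 0, 0],
                              [1, 3, 6]])

# Vectorized form from the commit: during decode only the newest token per
# sequence writes to the kv-cache, so its slot is determined by the id of
# the last block and the token's offset inside that block.
idx = (kv_seqlens - 1) % block_size         # offset within the last block
block_num = (kv_seqlens - 1) // block_size  # index of the last logical block
last_block = block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
kv_start_indices = last_block * block_size + idx

# Cross-check against a per-sequence loop (the shape of the old code path).
expected = torch.stack([
    block_offsets[i, (kv_seqlens[i] - 1) // block_size] * block_size
    + (kv_seqlens[i] - 1) % block_size
    for i in range(kv_seqlens.numel())
])
assert torch.equal(kv_start_indices, expected)
print(kv_start_indices)  # tensor([28, 23, 24])

One gather plus elementwise arithmetic replaces a Python iteration per sequence, which is what makes the decode phase cheaper for large batches.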

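On the prefill path the per-sequence loop remains, gathering a contiguous range of slots for the new tokens. A sketch of that indexing pattern follows, assuming total_slots lays physical slot ids out block by block; that layout is an assumption about what get_total_slots provides, chosen so the example is self-contained.

import torch

# Toy cache: 8 physical blocks of block_size 4 -> 32 slots, numbered
# block by block (assumed layout of what get_total_slots returns).
block_size, num_blocks = 4, 8
total_slots = torch.arange(num_blocks * block_size).view(num_blocks, block_size)

# One sequence: 6 tokens already cached (history), 3 new prefill tokens.
q_seq_len, kv_seq_len = 3, 9
history_length = kv_seq_len - q_seq_len
block_offsets_i = torch.tensor([2, 5, 7])  # logical -> physical block ids

# Flatten the sequence's blocks into a per-token slot table, then slice
# out the slots the NEW tokens will occupy.
slot_tables = total_slots[block_offsets_i].view(-1)
slots = slot_tables[history_length:kv_seq_len]
print(slots)  # tensor([22, 23, 28])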
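The paged-prefill mask is built the same way as before, only now behind the else branch: torch.tril with a diagonal offset of kv_seq_len - q_seq_len lets each new token attend to the full history plus its causal prefix, and logical_not flips the result into a "True means masked out" convention. A toy illustration with made-up sizes:

import torch

# 3 new query tokens appended to a history of 2 tokens; the padded kv
# width stands in for block_offsets.shape[1] * block_size in the diff.
q_seq_len, kv_seq_len, padded_kv_len = 3, 5, 8

# True marks positions a query token must NOT attend to. The diagonal
# offset shifts the causal boundary right by the history length, so the
# i-th new token sees all history plus the first i+1 new tokens.
mask = torch.logical_not(
    torch.tril(
        torch.ones(q_seq_len, padded_kv_len, dtype=torch.bool),
        diagonal=kv_seq_len - q_seq_len,
    ))
print(mask.int())
# tensor([[0, 0, 0, 1, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1, 1]])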