From b5b31791a76cf37ac3856d0394fa3eb5502217a6 Mon Sep 17 00:00:00 2001
From: jinminxi104
Date: Mon, 25 Nov 2024 20:30:58 +0800
Subject: [PATCH] Optimize update_step_ctx on Ascend (#2804)

* opt update_ctx for ascend

* fix lint
---
 .../backends/dlinfer/ascend/op_backend.py | 61 +++++++++++--------
 1 file changed, 36 insertions(+), 25 deletions(-)

diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index 79e528836..b6f544510 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -71,31 +71,42 @@ def get_total_slots():
         max_q_seq_len = max(q_seqlens_list)
         max_kv_seq_len = max(kv_seqlens_list)
 
-        for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = q_seqlens_list[i]
-            kv_seq_len = kv_seqlens_list[i]
-
-            # collect kv start indices.
-            history_length = kv_seq_len - q_seq_len
-            total_slots = get_total_slots()
-            slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
-            slots = slot_tables[history_length:kv_seq_len]
-            kv_start_indices.append(slots)
-
-            # collect attention mask of paged_prefill attention stage.
-            if not (step_context.is_decoding or is_unpaged_prefill):
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(q_seq_len,
-                                   step_context.block_offsets.shape[1] *
-                                   block_size,
-                                   dtype=torch.bool,
-                                   device=step_context.block_offsets.device),
-                        diagonal=kv_seq_len - q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
-
-        kv_start_indices = torch.cat(kv_start_indices)
+        if step_context.is_decoding:
+            # collect kv_start_indices without using a for-loop,
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            block_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = step_context.block_offsets.gather(
+                1, block_num.view(-1, 1)).view(-1)
+            kv_start_indices = last_block * block_size + idx
+        else:
+            for i in range(step_context.q_start_loc.size(0)):
+                q_seq_len = q_seqlens_list[i]
+                kv_seq_len = kv_seqlens_list[i]
+
+                # collect kv start indices during the prefill phase.
+                history_length = kv_seq_len - q_seq_len
+                total_slots = get_total_slots()
+                slot_tables = total_slots[step_context.block_offsets[i]].view(
+                    -1)
+                slots = slot_tables[history_length:kv_seq_len]
+                kv_start_indices.append(slots)
+
+                # collect attention mask of paged_prefill attention stage.
+                if not is_unpaged_prefill:
+                    single_attention_mask = torch.logical_not(
+                        torch.tril(
+                            torch.ones(
+                                q_seq_len,
+                                step_context.block_offsets.shape[1] *
+                                block_size,
+                                dtype=torch.bool,
+                                device=step_context.block_offsets.device),
+                            diagonal=kv_seq_len - q_seq_len,
+                        ))
+                    attention_mask.append(single_attention_mask)
+
+            kv_start_indices = torch.cat(kv_start_indices)
 
         if step_context.is_decoding:
             # prepare some params of paged_decode attention stage.
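
Reviewer note: the decode fast path replaces the per-sequence slot-table
walk with three tensor ops. The standalone sketch below (not part of the
patch; the toy block table, sequence lengths, and synthetic total_slots are
illustrative assumptions) checks that the vectorized indices match what the
old loop would have produced when q_seq_len == 1:

    import torch

    block_size = 4
    # block_offsets[i] lists the kv-cache block ids owned by sequence i.
    block_offsets = torch.tensor([[3, 7, 1], [5, 2, 6]])
    # Each decoding step appends exactly one token per sequence, so only
    # the slot at position kv_seqlens - 1 has to be located.
    kv_seqlens = torch.tensor([9, 5])

    # Old path: gather each sequence's flat slot table, then slice it.
    total_slots = torch.arange(16 * block_size).view(16, block_size)
    loop_indices = []
    for i in range(block_offsets.size(0)):
        slot_tables = total_slots[block_offsets[i]].view(-1)
        loop_indices.append(slot_tables[kv_seqlens[i] - 1:kv_seqlens[i]])
    loop_indices = torch.cat(loop_indices)

    # New path from the patch: locate the last block and the offset
    # inside it with pure tensor arithmetic, no Python loop.
    idx = (kv_seqlens - 1) % block_size
    block_num = (kv_seqlens - 1) // block_size
    last_block = block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
    vec_indices = last_block * block_size + idx

    assert torch.equal(loop_indices, vec_indices)
    print(vec_indices)  # tensor([4, 8])

The equivalence holds because a sequence's flat slot table is laid out
block by block: slot (kv_seq_len - 1) falls in block
(kv_seq_len - 1) // block_size at offset (kv_seq_len - 1) % block_size.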
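
Reviewer note: the paged_prefill attention mask is logically unchanged by
this patch; the step_context.is_decoding check simply moves into the
enclosing if/else. A small sketch of what the mask expression produces
(toy sizes, assumed for illustration; True marks positions a query token
must NOT attend to, and diagonal = kv_seq_len - q_seq_len shifts causal
visibility past the cached history tokens):

    import torch

    block_size, num_blocks = 4, 2
    q_seq_len, kv_seq_len = 3, 5  # 2 cached history tokens + 3 new tokens

    mask = torch.logical_not(
        torch.tril(
            torch.ones(q_seq_len, num_blocks * block_size, dtype=torch.bool),
            diagonal=kv_seq_len - q_seq_len,
        ))
    print(mask.int())
    # tensor([[0, 0, 0, 1, 1, 1, 1, 1],
    #         [0, 0, 0, 0, 1, 1, 1, 1],
    #         [0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)

Each of the three query rows can see the two history slots plus itself and
any earlier new tokens; columns past kv_seq_len (including unused padding
slots of the block table) stay masked.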