Cached linear SSM + causal conv support with Bamba demo by lucaslie · Pull Request #134 · nv-auto-deploy/TensorRT-LLM · GitHub

Conversation

@lucaslie
Collaborator

@lucaslie lucaslie commented Sep 19, 2025

See title and Slack.

@Copilot Copilot AI review requested due to automatic review settings September 19, 2025 22:49
@lucaslie lucaslie changed the base branch from feat/ad_coverage_week1 to feat/ad_linear_attention September 19, 2025 22:49

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
@lucaslie lucaslie changed the title Basic linear caching support Cached linear SSM + causal conv support with Bamba demo Sep 21, 2025
@lucaslie lucaslie self-assigned this Sep 21, 2025
@lucaslie lucaslie merged commit 84ac4cc into feat/ad_linear_attention Sep 21, 2025
3 of 5 checks passed
groups: int = 1,
padding_mode: str = "zeros",
) -> torch.Tensor:
assert padding_mode == "zeros", "padding_mode must be zeros"

Seems unused?

To stay close to the conv1d signature.
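For illustration, a minimal sketch (hypothetical names, not the actual op) of keeping signature parity with torch.nn.Conv1d while only supporting the default padding_mode:

import torch
import torch.nn.functional as F

# Sketch only: accept the same keyword arguments as torch.nn.Conv1d so call
# sites and graph patterns line up, even though padding_mode is merely checked.
def conv1d_like_sketch(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    assert padding_mode == "zeros", "padding_mode must be zeros"
    return F.conv1d(input, weight, bias, stride, padding, dilation, groups)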

- pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
sequence i. Note that, for example, cache_loc[p_0:p_1] will correspond to the pages associated
with sequence 1 in the batch.
- slot_idx: [s_0, s_1, ..., s_{b-1}]

What is a slot?

Sequence slot from the request object: a unique ID in the range [0, max_batch_size) assigned by the runtime.

Paged attention doesn't care about the sequence mapping; it only cares about which pages hold cache for a particular sequence.

For SSM, there's no notion of a page; you need the whole state.
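To make the distinction concrete, a small sketch (sizes and names are made up for illustration):

import torch

# SSM state cache: indexed directly by slot; one full state per sequence slot.
max_batch_size, num_heads, head_dim, state_dim = 8, 4, 64, 16
ssm_state_cache = torch.zeros(max_batch_size, num_heads, head_dim, state_dim)
slot_idx = torch.tensor([3, 0])          # slots assigned by the runtime for this batch
states = ssm_state_cache[slot_idx]       # whole state per sequence, no pages involved

# Paged KV cache: indexed indirectly through cache_loc / pages_per_seq.
num_pages, page_size = 32, 64
kv_cache = torch.zeros(num_pages, page_size, num_heads, head_dim)
cache_loc = torch.tensor([5, 9, 2])      # flattened page list for the batch
pages_per_seq = torch.tensor([2, 1])     # seq 0 -> cache_loc[0:2], seq 1 -> cache_loc[2:3]
seq1_pages = kv_cache[cache_loc[2:3]]    # pages holding sequence 1's cache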

)


def _segment_sum(input_tensor):

Should we reference where these come from?
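For context, a sketch of what a segment-sum helper typically computes in SSD/Mamba-2 style code, assuming it mirrors the Hugging Face Mamba2 reference:

import torch

def segment_sum_sketch(x: torch.Tensor) -> torch.Tensor:
    # Returns a [..., T, T] tensor where entry (i, j) is sum(x[..., j+1:i+1])
    # for i >= j and -inf above the diagonal.
    t = x.size(-1)
    expanded = x[..., None].expand(*x.shape, t)   # entry (i, j) equals x[..., i]
    mask = torch.tril(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=-1)
    segsum = torch.cumsum(expanded.masked_fill(~mask, 0), dim=-2)
    mask = torch.tril(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=0)
    return segsum.masked_fill(~mask, float("-inf"))

A = torch.randn(2, 3, 8)                  # e.g. per-step log-decay terms within a chunk
L = torch.exp(segment_sum_sketch(A))      # lower-triangular decay factors for the chunked scan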

y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, num_heads, head_dim)

return y, ssm_state

Might as well omit ssm_state from return; I'd added it originally to update the cache in the caller.

updated_cache = conv_state_cache.roll(shifts=-1, dims=-1)
# [B, T=1, C] -> [B, C]
new_sample_bc = input.transpose(1, 2)[..., 0]
updated_cache[:, :, -1] = new_sample_bc.to(updated_cache.dtype).to(updated_cache.device)

Out of curiosity, when would they end up on different devices?

Probably Cursor.
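For reference, a self-contained sketch of the roll-and-append decode update (shapes are illustrative; bias and activation omitted):

import torch

B, C, K = 2, 8, 4
conv_state_cache = torch.zeros(B, C, K)     # last K samples per channel
weight = torch.randn(C, 1, K)               # depthwise conv1d weight
new_token = torch.randn(B, 1, C)            # single decode step, [B, T=1, C]

# Shift the window left by one and write the newest sample into the last slot.
updated_cache = conv_state_cache.roll(shifts=-1, dims=-1)
updated_cache[:, :, -1] = new_token.transpose(1, 2)[..., 0]

# The decode output is a per-channel dot product with the conv weights.
y = torch.einsum("bck,ck->bc", updated_cache, weight.squeeze(1))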

end_i = start_i + length_i

mask_i = (flat_idx >= start_i.to(torch.long)) & (flat_idx < end_i.to(torch.long))
idx_i = torch.nonzero(mask_i, as_tuple=False).squeeze(-1)

Btw, this can cause host-device synchronization if mask_i is a CUDA tensor, since the output shape of torch.nonzero depends on the data.
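For illustration, one way around it, assuming the per-sequence start/length metadata is already available as host-side integers, is to build the indices with arange instead of a data-dependent nonzero:

import torch

seq_lens = [5, 3, 7]                        # Python ints on the host
starts = [0]
for ln in seq_lens[:-1]:
    starts.append(starts[-1] + ln)
# No CUDA mask is read back, so no implicit synchronization.
idx_per_seq = [torch.arange(s, s + ln) for s, ln in zip(starts, seq_lens)]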

cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]

Who is passing these in at runtime?

# Reference by per-sequence prefill
y_ref = torch.empty_like(y)
for i, ln in enumerate(lens):
st = 0 if i == 0 else lens[0]

Would this still be valid if lens had more than 2 elements?
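For illustration, one way to compute the start offsets that stays valid for any number of sequences (placeholder tensors, not the actual test code):

import torch

lens = [4, 2, 3]                            # per-sequence prefill lengths
y = torch.randn(1, sum(lens), 8)
y_ref = torch.empty_like(y)
start = 0
for ln in lens:
    # placeholder for the per-sequence reference computation on this slice
    y_ref[:, start:start + ln] = y[:, start:start + ln]
    start += ln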

def mamba_env():
device = "cuda"
dtype = torch.float16
atol = 5e-2

Most of the diffs are actually 0.0, the largest being one of the conv cache's state at 0.003. Might be worth tightening these bounds?
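For example, something in this ballpark would still pass on the reported diffs while catching larger regressions (values are a suggestion, not measured limits):

import torch

atol, rtol = 5e-3, 1e-3                     # vs. the current atol = 5e-2
torch.testing.assert_close(
    torch.tensor([0.0, 0.003]),             # diffs like those reported: mostly 0.0, max ~3e-3
    torch.tensor([0.0, 0.0]),
    atol=atol,
    rtol=rtol,
)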

)


@AttentionRegistry.register("torch_causal_conv")

This is what attn_backend in the config YAML maps to.
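A generic sketch of the pattern (not the actual AttentionRegistry API): the string from the config, e.g. attn_backend: torch_causal_conv, selects the registered class by name.

# Hypothetical stand-in for the real registry, for illustration only.
_REGISTRY: dict = {}

def register(name: str):
    def decorator(cls):
        _REGISTRY[name] = cls
        return cls
    return decorator

@register("torch_causal_conv")
class TorchCausalConvBackend:
    pass

backend_cls = _REGISTRY["torch_causal_conv"]   # what attn_backend resolves to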

lucaslie added a commit that referenced this pull request Sep 29, 2025
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
nvchenghaoz pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
nvchenghaoz pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>