Cached linear SSM + causal conv support with Bamba demo #134
Conversation
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Force-pushed from 7f9e189 to 1242bbb
    groups: int = 1,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    assert padding_mode == "zeros", "padding_mode must be zeros"
Seems unused?
It's kept to stay close to the conv1d signature.
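For context: F.conv1d itself doesn't take padding_mode; it's presumably nn.Conv1d's constructor being mirrored here (where padding_mode defaults to "zeros"), so callers can forward the module's parameters one-to-one. A small illustrative causal-conv usage, not taken from the PR:

import torch
import torch.nn as nn

# nn.Conv1d exposes groups and padding_mode; keeping the same keywords in the
# custom op lets an existing module's configuration be forwarded unchanged.
conv = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=4,
                 groups=8, padding=3, padding_mode="zeros", bias=True)
x = torch.randn(2, 8, 16)
y = conv(x)[..., : x.shape[-1]]  # causal: trim the right-side overhang
print(y.shape)  # torch.Size([2, 8, 16])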
    - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
      sequence i. Note that, for example, cache_loc[p_0:p_1] will correspond to the pages associated
      with sequence 1 in the batch.
    - slot_idx: [s_0, s_1, ..., s_{b-1}]
What is a slot?
The sequence slot from the request object: a unique ID in the range [0, max_batch_size) assigned by the runtime.
Paged attention doesn't care about the sequence mapping; it only cares about which pages hold cache for a particular sequence.
For SSM, there's no notion of a page; you need the whole state.
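To make the page/slot distinction concrete, here's a minimal sketch (all names and shapes are hypothetical, not the PR's actual identifiers): paged attention gathers whatever pages cache_loc/pages_per_seq say belong to a sequence, while the SSM and conv caches are indexed directly by slot_idx.

import torch

max_batch_size = 8
page_size, num_pages, head_dim, state_dim, kernel_size = 16, 64, 64, 128, 4

# Hypothetical cache layouts (illustration only).
kv_page_pool = torch.zeros(num_pages, page_size, head_dim)             # paged KV cache
ssm_state_cache = torch.zeros(max_batch_size, head_dim, state_dim)     # per-slot SSM state
conv_state_cache = torch.zeros(max_batch_size, head_dim, kernel_size)  # per-slot conv window

cache_loc = torch.tensor([3, 7, 12, 5])  # flattened page ids for the whole batch
pages_per_seq = torch.tensor([3, 1])     # sequence 0 owns 3 pages, sequence 1 owns 1
slot_idx = torch.tensor([5, 2])          # runtime-assigned slots for the 2 sequences

# Paged attention: gather whichever pages hold sequence 1's cache.
page_offsets = torch.cumsum(pages_per_seq, 0) - pages_per_seq                 # [0, 3]
pages_seq1 = cache_loc[page_offsets[1] : page_offsets[1] + pages_per_seq[1]]  # tensor([5])
kv_seq1 = kv_page_pool[pages_seq1]       # [1, page_size, head_dim]

# SSM / causal conv: no pages; the whole recurrent state lives at the sequence's slot.
ssm_seq1 = ssm_state_cache[slot_idx[1]]   # [head_dim, state_dim]
win_seq1 = conv_state_cache[slot_idx[1]]  # [head_dim, kernel_size]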
)


def _segment_sum(input_tensor):
Should we add a reference for where these come from?
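For context, a helper with this name appears in Hugging Face Transformers' Mamba-2 modeling code, which is presumably the origin; the reference version looks roughly like this (the copy in this PR may differ):

import torch

def segment_sum(input_tensor):
    """segsum[..., i, j] = sum of input[..., j+1 : i+1] for j <= i; -inf above the diagonal."""
    chunk_size = input_tensor.size(-1)
    # Broadcast the last dim into a [chunk_size, chunk_size] grid: grid[..., i, j] = input[..., i].
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # Zero out elements at or above the diagonal before the cumulative sum.
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool,
                                 device=input_tensor.device), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)
    # Keep only the lower triangle (incl. diagonal); everything else becomes -inf.
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool,
                                 device=input_tensor.device), diagonal=0)
    return tensor_segsum.masked_fill(~mask, -torch.inf)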
    y = y[:, :seq_len, :, :]
    y = y.reshape(batch_size, seq_len, num_heads, head_dim)

    return y, ssm_state
Might as well omit ssm_state from return; I'd added it originally to update the cache in the caller.
    updated_cache = conv_state_cache.roll(shifts=-1, dims=-1)
    # [B, T=1, C] -> [B, C]
    new_sample_bc = input.transpose(1, 2)[..., 0]
    updated_cache[:, :, -1] = new_sample_bc.to(updated_cache.dtype).to(updated_cache.device)
Out of curiosity, when would they end up on different devices?
Probably Cursor.
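For reference, a minimal sketch of the decode-step update this snippet belongs to, assuming a depthwise causal conv with a [B, C, K] rolling state (variable names are illustrative, not the PR's):

import torch
import torch.nn.functional as F

def causal_conv_decode_step(input, weight, bias, conv_state_cache):
    """One-token decode: shift the window, append the new sample, run the depthwise conv.

    input:            [B, T=1, C] new token's channels
    weight:           [C, 1, K]   depthwise conv weights
    conv_state_cache: [B, C, K]   last K samples per sequence/channel
    """
    # Shift the window left by one and write the newest sample into the last position.
    updated_cache = conv_state_cache.roll(shifts=-1, dims=-1)
    new_sample_bc = input.transpose(1, 2)[..., 0]  # [B, C]
    updated_cache[:, :, -1] = new_sample_bc.to(updated_cache.dtype)

    # Depthwise conv over the K-sample window yields exactly one output sample per channel.
    y = F.conv1d(updated_cache, weight, bias, groups=updated_cache.shape[1])  # [B, C, 1]
    return y.transpose(1, 2), updated_cache  # [B, T=1, C], [B, C, K]

# Usage sketch
B, C, K = 2, 8, 4
y, state = causal_conv_decode_step(
    torch.randn(B, 1, C), torch.randn(C, 1, K), torch.zeros(C), torch.zeros(B, C, K)
)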
    end_i = start_i + length_i

    mask_i = (flat_idx >= start_i.to(torch.long)) & (flat_idx < end_i.to(torch.long))
    idx_i = torch.nonzero(mask_i, as_tuple=False).squeeze(-1)
Btw this can cause host-device synchronization if mask_i is a CUDA tensor.
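The sync comes from the data-dependent output shape of nonzero: it has to know how many elements are True before it can allocate the result. If the selected indices really form a contiguous block [start_i, end_i) as the mask suggests, the same result can be produced without evaluating a mask; a hedged sketch, assuming start/length are available as host integers (e.g. fetched once for the whole batch):

import torch

def contiguous_idx(start_i: int, length_i: int, device) -> torch.Tensor:
    # Equivalent to nonzero on a [start_i, end_i) range mask, but with a
    # statically known output size, so no host-device synchronization.
    return torch.arange(start_i, start_i + length_i, dtype=torch.long, device=device)

# Or skip the index tensor entirely and slice the flattened tensors directly:
# x_i = x_flat[start_i : start_i + length_i]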
    cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
    inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
    w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]
Who is passing these in at runtime?
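If these are the usual FX meta entries, nobody passes them at runtime: torch.export / fake-tensor tracing records a FakeTensor per node in node.meta["val"], so shapes and dtypes can be read offline. A hypothetical sketch of how an initializer might consume them (function name and shapes are assumptions, not the PR's code):

import torch
from torch.fx import Node

def init_conv_state_cache(source_attn_node: Node, max_batch_size: int) -> torch.Tensor:
    # Fake tensors recorded at trace/export time carry static shape and dtype info.
    inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"]  # e.g. [B, T, C]
    w_fake: torch.Tensor = source_attn_node.args[1].meta["val"]    # e.g. [C, 1, K]
    num_channels, kernel_size = w_fake.shape[0], w_fake.shape[-1]
    return torch.zeros(max_batch_size, num_channels, kernel_size, dtype=inp_fake.dtype)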
    # Reference by per-sequence prefill
    y_ref = torch.empty_like(y)
    for i, ln in enumerate(lens):
        st = 0 if i == 0 else lens[0]
Would this still be valid if lens had more than 2 elements?
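It looks like it wouldn't be: with three or more sequences, lens[0] is only the offset of the second one, so every later sequence would reuse the same start. A cumulative-offset sketch that generalizes (illustrative, not the PR's code):

import itertools
import torch

lens = [5, 3, 7]                                      # per-sequence prefill lengths
starts = [0] + list(itertools.accumulate(lens))[:-1]  # [0, 5, 8]

y = torch.randn(1, sum(lens), 8)
y_ref = torch.empty_like(y)
for st, ln in zip(starts, lens):
    seg = y[:, st : st + ln]
    # ... run the per-sequence reference op on `seg` here ...
    y_ref[:, st : st + ln] = seg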
def mamba_env():
    device = "cuda"
    dtype = torch.float16
    atol = 5e-2
Most of the diffs are actually 0.0, the largest being one of the conv cache's state at 0.003. Might be worth tightening these bounds?
)


@AttentionRegistry.register("torch_causal_conv")
This is what the config YAML's attn_backend field maps to.
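For readers unfamiliar with the pattern, a minimal sketch of how such a string-keyed registry typically works, so that attn_backend: torch_causal_conv in the YAML resolves to the registered class. This is an assumption about the general pattern, not the repo's actual AttentionRegistry implementation:

from typing import Callable, Dict, Type

class AttentionRegistry:
    """Illustrative string -> backend-class registry (not the repo's real code)."""
    _backends: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def _wrap(backend_cls: Type) -> Type:
            cls._backends[name] = backend_cls
            return backend_cls
        return _wrap

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._backends[name]

@AttentionRegistry.register("torch_causal_conv")
class TorchCausalConv:  # hypothetical backend class
    pass

# A config entry like `attn_backend: torch_causal_conv` would then resolve via:
backend_cls = AttentionRegistry.get("torch_causal_conv")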
See the PR title and Slack for context.