- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.4k
make group offloading work with disk/nvme transfers #11682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really cool work at getting this started!
- Can we see some results with more compute heavy model like Wan?
- We probably need to look at some profiles to see if there is overlapping happening here when streams are used with disk-offload (reason: I think there's a blocking operation which prevents this, but not 100% sure)
- re: stark background color difference; Weird, I'll take a look
- can we also benchmark the disk memory usage?
Edit: For the benchmark, I think a fair comparison for all methods would require us to use group offloading on all components instead of just transformer. Maybe the benchmark could be updated to show the memory usages with (1) just transformer, (2) all components
| # Load to CPU, pin, and async copy to device for overlapping transfer and compute | ||
| loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") | ||
| for key, tensor_obj in self.key_to_tensor.items(): | ||
| pinned_tensor = loaded_cpu_tensors[key].pin_memory() | ||
| tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
| if self.record_stream: | ||
| tensor_obj.data.record_stream(current_stream) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think cleaner approach would be to provide a callable to map_location (assuming we were using torch.load instead of safetensors), which for each tensor can pin and move to device. Do we know if there is a equivalent to passing a callable with safetensors? If not, this is okay too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we know if there would be other alternatives to this code path? If not, I think it's better as is. From skimming through the documentation of safetensors, I couldn't find any equivalent of map_location.
| @torch.compiler.disable() | ||
| def offload_(self): | ||
| r"""Offloads the group of modules to the offload_device.""" | ||
| if self.offload_to_disk: | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit): we probably need to refactor this a bit and break into smaller methods so we don't have to branch and do early-returns every time a new feature is added (we can do refactor once we have everything working, so not urgent)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. Can I do it in an immediate follow-up PR so that it's easier to review?
| self._is_offloaded_to_disk = True | ||
|  | ||
| for tensor_obj in self.tensor_to_key.keys(): | ||
| tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason for this to be different from the non-disk-offload counterpart? That is, is there a reason we're not doing buffer.data.to(self.offload_device, non_blocking=self.non_blocking)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, we first free up the memory of the accelerator with:
| key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() | 
However, since we're also optimizing for RAM usage (can be made clearer through documentation I believe), we need to free up the RAM that is holding the tensor data. After the data has been safely written from RAM to the disk, this step replaces the large data tensor in RAM with a memory-less placeholder. This allows the memory to be released.
| @a-r-r-o-w thanks for your comments. I will work on them. 
 I can gather this. Should we gather the CPU and GPU activities through the profiler and export a trace? If you have any references for me to consider, feel free to send over. | 
| 
 I think CPU/GPU activities will measure all the operations in the model, so if you could collect the filtered stream-related operations and related onloading/offloading times, it'll be helpful! You probably already know but for the readers, this will help with gathering the traces for visualization: | 
| @a-r-r-o-w for Wan: 
 | 
| Awesome, thanks for sharing! The numbers look good. re: weird color results with group offloading; i looked into it and seems to only happen with block level (I don't know why yet). I think it could be because of some incorrect/missing synchronization somewhere, so will try to fix. If you use leaf_level, it should produce the same result. | 
| @a-r-r-o-w could I ask for another review at this point? I have also added a simple test to make sure it's working but I can also add a heavier integration test that checks for VRAM and RAM usage. Apart from that, I think only doc is missing. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the current implementation looks good to me for an initial merge after cleanup/docs! I haven't had the chance to look into memory profile yet, but the Wan benchmarks you shared look correct and as expected to me
| @a-r-r-o-w thanks! I think your comments are all addressed now. I have added docs, as well. LMK if there are any other comments address. I am also working further optimizing this. Stay tuned 🤗 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for the awesome feature, LGTM!
| @stevhliu could you review the docs once you have a moment? | 
| Failing tests are unrelated as I have locally run them on a CPU and they pass. Going to merge after this round of CI (have checked with @DN6 offline and we're good to merge). | 
What does this PR do?
Group offloading is a crucial feature to provide a good speed-memory trade-off for large models on consumer hardware. However, since group offloading relies quite a bit on RAM usage, it can be bottlenecked by its availability. As such, for machines where GPU VRAM > available RAM or machines have limited RAM, group offloading can be far from ideal.
This PR takes a stab at supporting disk/NMVe serialization/deserialization inside group offloading so that users can use the secondary memory to onload/offload model params while also benefiting from the overlapping between compute and data transfer.
Below are some numbers I have gathered with this PR:
Code
Quality comparison:

The stark background color difference in regular group offloading exists in the
mainbranch as well. So, I am not sure what is happening there.Group offloading with disk serialization/deserialization works with
torch.compile(), too.This PR is a PoC and hence, it has some things that can be made better. I'd be fine if the PR is completely dropped or if someone else wants to take it over and see it to completion. Otherwise, I am completely fine working on it.
@asomoza I think you will be quite interested in this one.