🚀 The feature, motivation and pitch
There is currently no function in the `torch.nested` namespace to create nested tensors from a buffer combined with sizes and (optionally, I think) strides and storage offsets. However, the C++ code already has functions for this kind of operation, e.g. `wrap_buffer` in `aten/src/ATen/native/nested/NestedTensorUtils.h`:
(`pytorch/aten/src/ATen/native/nested/NestedTensorUtils.h`, lines 35 to 45 at commit 0402492)

```cpp
inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) {
  TORCH_CHECK(
      buffer.dim() == 1,
      "Expected given buffer to be 1dim, but got ",
      buffer.dim(),
      " instead.");
  TORCH_CHECK(
      buffer.is_contiguous(), "Expected given buffer to be contiguous.");
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(buffer), std::move(nested_sizes));
}
```
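If the buffer is contiguous, as `wrap_buffer` requires, the per-component strides and storage offsets appear to be fully determined by the sizes, which is why they could be optional arguments. A minimal pure-Python sketch under that assumption (components packed back to back in row-major order; plain lists stand in for the rows of the `nested_sizes` tensor; `contiguous_strides`, `storage_offsets` and `unpack` are hypothetical helpers, not PyTorch API):

```python
from math import prod
from typing import List, Sequence, Tuple

# Hypothetical helpers modelling a packed nested layout in plain Python.
# Assumption: every component lives in one contiguous 1-D buffer,
# stored back to back, each in row-major (C-contiguous) order.


def contiguous_strides(size: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides (in elements) for one component of the given size."""
    strides = []
    step = 1
    for dim in reversed(size):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def storage_offsets(nested_sizes: Sequence[Sequence[int]]) -> List[int]:
    """Start index of each component: running sum of element counts."""
    offsets, start = [], 0
    for size in nested_sizes:
        offsets.append(start)
        start += prod(size)
    return offsets


def unpack(buffer: Sequence, nested_sizes: Sequence[Sequence[int]]) -> list:
    """Recover each component as a flat slice of the buffer."""
    return [
        buffer[off : off + prod(size)]
        for off, size in zip(storage_offsets(nested_sizes), nested_sizes)
    ]


nested_sizes = [[2, 3], [1, 3]]
print([contiguous_strides(s) for s in nested_sizes])  # [(3, 1), (3, 1)]
print(storage_offsets(nested_sizes))                  # [0, 6]
print(unpack(list(range(9)), nested_sizes))           # [[0, 1, 2, 3, 4, 5], [6, 7, 8]]
```

For 1-D sequences like the token ids in the benchmark script, every stride is `(1,)` and the offsets are simply the running sum of the sequence lengths.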
I would be willing to open a corresponding PR if that is desired. Maybe this feature could also be included in the promotion of nested tensors to the beta stage (#112398).
The following benchmark shows the motivation: storing each dataset column as one flat buffer plus the nested metadata, and rebuilding it with `torch._nested_view_from_buffer`, is much faster than saving and loading lists of tensors.

```python
import torch
from time import time

# create a random dataset of variable sequence lengths (10k elements)
sizes = torch.randint(32, 256, size=(10000,))
dataset = {
    'input_ids': [torch.randint(32000, size=(i,)) for i in sizes],
    'labels': [torch.randint(32000, size=(i,)) for i in sizes],
}


def measure(name, func):
    s = time()
    func()
    print(f"{name} took {time() - s} seconds")


# save the dataset
measure("store_dict", lambda: torch.save(dataset, "dataset.pt"))
# load the dataset
measure("load_dict", lambda: torch.load("dataset.pt", weights_only=True))


# convert the dataset to nested tensors
def to_nested_tensor_info(ds: dict):
    return {k: torch.nested.nested_tensor(v) for k, v in ds.items()}


def save_nested_tensor(ds: dict):
    nested_ds = to_nested_tensor_info(ds)
    torch.save({
        k: (
            torch.cat(v, dim=-1),  # one large tensor
            # nested metadata: per-component sizes, strides and storage offsets
            nested_ds[k]._nested_tensor_size(),
            nested_ds[k]._nested_tensor_strides(),
            nested_ds[k]._nested_tensor_storage_offsets(),
        ) for k, v in ds.items()
    }, "nested.pt")


def load_nested_tensor():
    tuple_ds = torch.load("nested.pt", weights_only=True)
    # rebuild each nested tensor as a view over the loaded buffer
    return {
        k: torch._nested_view_from_buffer(*v) for k, v in tuple_ds.items()
    }


# save the dataset
measure("store_nested", lambda: save_nested_tensor(dataset))
# load the dataset
measure("load_nested", lambda: print(load_nested_tensor()['input_ids']))
```

Alternatives
There is `torch._nested_view_from_buffer`, but it is private (underscore-prefixed) and not part of the `torch.nested` namespace, so I am unsure about its future.
Additional context
Output of the benchmark script:

```
store_dict took 0.5521411895751953 seconds
load_dict took 1.3255178928375244 seconds
store_nested took 0.10043215751647949 seconds
load_nested took 0.009233236312866211 seconds
```
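The reported timings can be turned into speedup ratios with a short calculation. This sketch only divides the numbers above; they come from a single run each, so the ratios are indicative, and the `load_nested` timing also includes the `print` call made in the script.

```python
# Timings reported by the benchmark script, in seconds (single run each).
timings = {
    "store": {"dict": 0.5521411895751953, "nested": 0.10043215751647949},
    "load": {"dict": 1.3255178928375244, "nested": 0.009233236312866211},
}


def speedup(op: str) -> float:
    """How many times faster the buffer + metadata format is for `op`."""
    return timings[op]["dict"] / timings[op]["nested"]


for op in timings:
    print(f"{op}: {speedup(op):.1f}x faster with the buffer format")
```

This works out to roughly 5.5x for storing and about 144x for loading.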
cc @mruberry @mikaylagawarecki @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer