Add a function to torch.nested to create nested tensors from a buffer and sizes · Issue #112509 · pytorch/pytorch · GitHub


@fleonce

Description


🚀 The feature, motivation and pitch

There is currently no function in the torch.nested namespace for creating nested tensors from a buffer combined with sizes and (optionally, I think) strides and storage offsets. The C++ code, however, already has functions for this kind of operation, for example wrap_buffer in aten/src/ATen/native/nested/NestedTensorUtils.h:

```cpp
inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) {
  TORCH_CHECK(
      buffer.dim() == 1,
      "Expected given buffer to be 1dim, but got ",
      buffer.dim(),
      " instead.");
  TORCH_CHECK(
      buffer.is_contiguous(), "Expected given buffer to be contiguous.");
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(buffer), std::move(nested_sizes));
}
```
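wrap_buffer presumably only needs the buffer and the sizes because, when the components are packed back to back in one contiguous 1-D buffer, their strides and storage offsets follow from the sizes alone. The sketch below shows that derivation in plain Python, assuming row-major packing. contiguous_strides and nested_metadata_from_sizes are hypothetical names, not PyTorch APIs, and the element-count check is an extra sanity check that wrap_buffer itself does not perform.

```python
from itertools import accumulate
from math import prod


def contiguous_strides(size):
    """Row-major (C-contiguous) strides for one component's shape."""
    strides = []
    step = 1
    for dim in reversed(size):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def nested_metadata_from_sizes(buffer_numel, nested_sizes):
    """Hypothetical helper: derive per-component strides and storage offsets
    for components stored back to back in one contiguous 1-D buffer."""
    numels = [prod(size) for size in nested_sizes]
    # Extra check (not in wrap_buffer): the buffer must hold exactly these elements.
    if sum(numels) != buffer_numel:
        raise ValueError(
            f"buffer has {buffer_numel} elements, sizes require {sum(numels)}"
        )
    strides = [contiguous_strides(size) for size in nested_sizes]
    # Offset of component i = total element count of components 0..i-1.
    offsets = list(accumulate(numels, initial=0))[:-1]
    return strides, offsets
```

For two components of shapes (2, 3) and (4, 3) in an 18-element buffer, this yields strides [(3, 1), (3, 1)] and offsets [0, 6], i.e. exactly the metadata that would otherwise have to be passed explicitly.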
I use nested tensors to speed up storing and loading my dataset, as the small example below shows, so this feature would be greatly appreciated.

I would be willing to open a corresponding PR if that is desired. Maybe this feature could also be included in the promotion of nested tensors to beta (#112398).

```python
import torch
from time import time

# create a random dataset of variable sequence lengths (10k elements)
sizes = torch.randint(32, 256, size=(10000,))
dataset = {
    'input_ids': [torch.randint(32000, size=(i,)) for i in sizes.tolist()],
    'labels': [torch.randint(32000, size=(i,)) for i in sizes.tolist()],
}

def measure(name, func):
    # time a single call of func
    s = time()
    func()
    print(f"{name} took {time() - s} seconds")

# save the dataset as a dict of lists of tensors
measure("store_dict", lambda: torch.save(dataset, "dataset.pt"))
# load the dataset
measure("load_dict", lambda: torch.load("dataset.pt", weights_only=True))

# convert each list of tensors to a nested tensor
def to_nested_tensor_info(ds: dict):
    return {k: torch.nested.nested_tensor(v) for k, v in ds.items()}

def save_nested_tensor(ds: dict):
    nested_ds = to_nested_tensor_info(ds)
    # store one flat buffer plus the nested metadata for each key
    torch.save({
        k: (
            torch.cat(v, dim=-1),  # one large 1-D buffer
            nested_ds[k]._nested_tensor_size(),
            nested_ds[k]._nested_tensor_strides(),
            nested_ds[k]._nested_tensor_storage_offsets(),
        ) for k, v in ds.items()
    }, "nested.pt")

def load_nested_tensor():
    tuple_ds = torch.load("nested.pt", weights_only=True)
    # rebuild each nested tensor as a view into its loaded buffer
    return {
        k: torch._nested_view_from_buffer(*v) for k, v in tuple_ds.items()
    }

# save the dataset
measure("store_nested", lambda: save_nested_tensor(dataset))
# load the dataset (this timing also includes printing the result)
measure("load_nested", lambda: print(load_nested_tensor()['input_ids']))
```
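The storage scheme used by save_nested_tensor and load_nested_tensor can be illustrated without PyTorch. The sketch below uses plain Python lists; pack and unpack are hypothetical helpers, not PyTorch APIs. Instead of serializing 10,000 separate tensors per key, only one flat buffer plus a small table of sizes is written, and the original sequences are recovered by slicing at cumulative offsets.

```python
from itertools import accumulate


def pack(sequences):
    """Concatenate variable-length sequences into one flat buffer plus lengths."""
    buffer = [x for seq in sequences for x in seq]
    lengths = [len(seq) for seq in sequences]
    return buffer, lengths


def unpack(buffer, lengths):
    """Recover the sequences by slicing the buffer at cumulative offsets."""
    offsets = list(accumulate(lengths, initial=0))
    return [buffer[start:end] for start, end in zip(offsets, offsets[1:])]
```

With nested tensors the slicing step is replaced by views into the loaded buffer, which is presumably why load_nested is so much cheaper than loading thousands of individually pickled tensors.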

Alternatives

torch._nested_view_from_buffer exists, but it is underscore-prefixed and not part of the torch.nested namespace, so I am unsure about its future.

Additional context

Timings from running the example above:

```
store_dict took 0.5521411895751953 seconds
load_dict took 1.3255178928375244 seconds
store_nested took 0.10043215751647949 seconds
load_nested took 0.009233236312866211 seconds
```
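For scale, the four reported timings can be turned into speedup factors. This small sketch only does arithmetic on those numbers; note that they come from a single run each, and that the load_nested figure also includes printing the loaded nested tensor.

```python
# Reported timings (seconds) from a single run of the example.
timings = {
    "store_dict": 0.5521411895751953,
    "load_dict": 1.3255178928375244,
    "store_nested": 0.10043215751647949,
    "load_nested": 0.009233236312866211,
}

store_speedup = timings["store_dict"] / timings["store_nested"]
load_speedup = timings["load_dict"] / timings["load_nested"]
print(f"store: {store_speedup:.1f}x faster, load: {load_speedup:.1f}x faster")
```

Saving is therefore roughly 5.5x faster and loading roughly 140x faster with the buffer-based format.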

cc @mruberry @mikaylagawarecki @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer

Metadata

Assignees: No one assigned

Labels: module: nestedtensor, module: serialization, triaged
