KEMBAR78
Memory explodes when applying Linear to reshaped nested tensor · Issue #141112 · pytorch/pytorch · GitHub
Skip to content

Memory explodes when applying Linear to reshaped nested tensor #141112

@mahyarkoy

Description

@mahyarkoy

🐛 Describe the bug

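Running the following script reproduces the problem. Applying a simple Linear layer to a jagged nested tensor that has been reshaped directly (with `view`) makes the reserved (cached) CUDA memory keep growing from step to step. The growth does not occur if the reshape is instead done manually on the underlying flattened `values()` buffer and the nested tensor is rebuilt afterwards.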

import torch
import torch.nn as nn

class MLP(nn.Module):
    """Minimal up-project / down-project module used to reproduce the issue.

    ``fc_upsample`` maps ``in_dim`` features to ``upsample_rate * embed_dim``
    features.  The result is split into ``upsample_rate`` chunks of
    ``embed_dim`` features, ``fc_outer`` maps each chunk to ``in_dim``
    features, and the chunks are flattened back into one trailing dimension.

    Args:
        in_dim: Feature size of the input and of each ``fc_outer`` output.
        embed_dim: Feature size of each chunk fed to ``fc_outer``.
        upsample_rate: Number of chunks produced per input element.
        explode: If true, ``forward`` reshapes the nested tensor directly
            (the path reported to blow up cached memory); otherwise it
            reshapes the underlying ``values()`` buffer (the workaround).
    """

    def __init__(self, in_dim=256, embed_dim=1024, upsample_rate=4, explode=False):
        super().__init__()
        self.in_dim, self.embed_dim = in_dim, embed_dim
        self.upsample_rate, self.explode = upsample_rate, explode
        # Layers are created in this order so parameter order and default
        # initialisation stay the same.
        self.fc_upsample = nn.Linear(in_dim, upsample_rate * embed_dim)
        self.fc_outer = nn.Linear(embed_dim, in_dim)

    def forward(self, in_tensor):
        """Apply both linear layers, reshaping between them.

        In the script this is called with a jagged-layout nested tensor; the
        workaround path relies on ``values()`` and ``offsets()`` being
        available on the input.
        """
        # Up-projection: last dim becomes upsample_rate * embed_dim.
        hidden = self.fc_upsample(in_tensor)
        chunk_shape = (self.upsample_rate, self.embed_dim)

        if not self.explode:
            # ***WORK AROUND*** Reshape the flat values buffer instead of the
            # nested tensor itself, then rebuild the nested tensor from the
            # input's offsets.  Cached memory usage stays fixed on this path.
            hidden = hidden.values()
            hidden = hidden.view(hidden.size(0), *chunk_shape)
            hidden = self.fc_outer(hidden).flatten(-2)
            return torch.nested.nested_tensor_from_jagged(hidden, offsets=in_tensor.offsets())

        # ***BUG*** Viewing the nested tensor directly is what triggers the
        # exploding cached memory usage reported in this issue (cause unknown).
        hidden = hidden.view(hidden.size(0), hidden.size(1), *chunk_shape)
        return self.fc_outer(hidden).flatten(-2)

### Setup
# Reproduction driver: trains the MLP above on random jagged batches and
# prints peak CUDA memory statistics after every optimizer step.
in_dim = 512
embed_dim = 512
max_seq_len = 100
device = 'cuda:0'

### Change explode to False to observe stable behavior
net = MLP(in_dim=in_dim, embed_dim=embed_dim, upsample_rate=4, explode=True).to(device)
# net = MLP(in_dim=in_dim, embed_dim=embed_dim, upsample_rate=4, explode=False).to(device)

optim = torch.optim.Adam(net.parameters())

### Train
num_steps = 1000
batch_size = 64
for step in range(num_steps):
    ### Forward, backward, optimize
    optim.zero_grad()
    # Build a fresh batch of `batch_size` random sequences on the CPU, each
    # with its own length drawn by torch.randint (upper bound max_seq_len),
    # so the ragged structure differs from step to step.
    in_tensor = [torch.rand(size=[seq_len, in_dim]) for seq_len in torch.randint(max_seq_len, size=(batch_size,))]
    # Pack the list into a single jagged-layout nested tensor and move it to the GPU.
    in_tensor = torch.nested.as_nested_tensor(in_tensor, layout=torch.jagged).to(device)
    loss = net(in_tensor).values().pow(2).mean() ### Dummy weight decay loss
    loss.backward()
    optim.step()

    ### Record average parameter magnitude
    # Mean absolute value over all parameters; printed as a sanity check
    # that the optimizer is updating the weights.
    with torch.no_grad():
        param_mag = 0.
        count = 0
        for param in net.parameters():
            param_mag = param_mag + param.abs().sum().item()
            count = count + param.numel()
        param_mag /= count

    ### Print memory usage
    # Both statistics are queried for device index 0, matching 'cuda:0' above.
    # NOTE: `//` floor-divides the byte counts, so the values are whole GiB;
    # the `:.3f` format below therefore always prints `.000`, and anything
    # under 1 GiB is reported as 0.000.
    max_allocated = torch.cuda.max_memory_allocated(0) // 1024 ** 3
    max_reserved = torch.cuda.max_memory_reserved(0) // 1024 ** 3
    print(f'>>> Step={step}, VRAM={max_allocated:.3f}<{max_reserved:.3f}GiB, param_mag={param_mag:.3f}')

Output with explode=True

# >>> Step=0, VRAM=2.000<3.000GiB, param_mag=0.022
# >>> Step=1, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=2, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=3, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=4, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=5, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=6, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=7, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=8, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=9, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=10, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=11, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=12, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=13, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=14, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=15, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=16, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=17, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=18, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=19, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=20, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=21, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=22, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=23, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=24, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=25, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=26, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=27, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=28, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=29, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=30, VRAM=3.000<20.000GiB, param_mag=0.021

Output with explode=False

# >>> Step=0, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=1, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=2, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=3, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=4, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=5, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=6, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=7, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=8, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=9, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=10, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=11, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=12, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=13, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=14, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=15, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=16, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=17, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=18, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=19, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=20, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=21, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=22, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=23, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=24, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=25, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=26, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=27, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=28, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=29, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=30, VRAM=0.000<0.000GiB, param_mag=0.021

Versions

Nightly pytorch from pip:

torch 2.6.0.dev20241028+cu121 pypi_0 pypi
torchaudio 2.5.0.dev20241028+cu121 pypi_0 pypi
torchvision 0.20.0.dev20241028+cu121 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Metadata

Metadata

Assignees

No one assigned

    Labels

    high priority · module: nestedtensor (NestedTensor tag, see issue #25032) · triage review · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions