-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Closed
Labels
high priority · module: nestedtensor · NestedTensor tag, see issue #25032 · triage review · triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
🐛 Describe the bug
Running the following script should reproduce the problem. Applying a simple Linear layer to a nested tensor results in cached memory explosion if the nested tensor is not manually flattened.
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, in_dim=256, embed_dim=1024, upsample_rate=4, explode=False):
super().__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.upsample_rate = upsample_rate
self.explode = explode
self.fc_upsample = nn.Linear(in_dim, upsample_rate * embed_dim)
self.fc_outer = nn.Linear(embed_dim, in_dim)
def forward(self, in_tensor):
### Upsample
tensor = self.fc_upsample(in_tensor)
### ***BUG*** Reshaping nested tensor directly results in exploding cached memory usage (Why???)
if self.explode:
tensor = tensor.view(tensor.size(0), tensor.size(1), self.upsample_rate, self.embed_dim)
tensor = self.fc_outer(tensor).flatten(-2)
return tensor
### ***WORK AROUND*** This indirect reshaping results in fixed cached memory usage
else:
tensor = tensor.values()
tensor = tensor.view(tensor.size(0), self.upsample_rate, self.embed_dim)
tensor = self.fc_outer(tensor).flatten(-2)
return torch.nested.nested_tensor_from_jagged(tensor, offsets=in_tensor.offsets())
### Setup
in_dim = 512
embed_dim = 512
max_seq_len = 100
device = 'cuda:0'

### Change explode to False to observe stable behavior
net = MLP(in_dim=in_dim, embed_dim=embed_dim, upsample_rate=4, explode=True).to(device)
# net = MLP(in_dim=in_dim, embed_dim=embed_dim, upsample_rate=4, explode=False).to(device)
optim = torch.optim.Adam(net.parameters())

### Train
num_steps = 1000
batch_size = 64
for step in range(num_steps):
    ### Forward, backward, optimize
    optim.zero_grad()
    # NOTE(review): torch.randint(max_seq_len, ...) can draw seq_len == 0;
    # jagged nested tensors accept empty rows, but raise the low bound if
    # zero-length sequences are not intended.
    in_tensor = [torch.rand(size=[seq_len, in_dim]) for seq_len in torch.randint(max_seq_len, size=(batch_size,))]
    in_tensor = torch.nested.as_nested_tensor(in_tensor, layout=torch.jagged).to(device)
    loss = net(in_tensor).values().pow(2).mean()  ### Dummy weight decay loss
    loss.backward()
    optim.step()
    ### Record average parameter magnitude
    with torch.no_grad():
        param_mag = 0.
        count = 0
        for param in net.parameters():
            param_mag = param_mag + param.abs().sum().item()
            count = count + param.numel()
        param_mag /= count
    ### Print memory usage
    # FIX: the original used floor division (// 1024 ** 3), which truncates to
    # whole GiB and defeats the '.3f' format below (output showed e.g. exactly
    # 3.000). True division reports fractional GiB as intended.
    max_allocated = torch.cuda.max_memory_allocated(0) / 1024 ** 3
    max_reserved = torch.cuda.max_memory_reserved(0) / 1024 ** 3
    print(f'>>> Step={step}, VRAM={max_allocated:.3f}<{max_reserved:.3f}GiB, param_mag={param_mag:.3f}')

### Output with explode=True
# >>> Step=0, VRAM=2.000<3.000GiB, param_mag=0.022
# >>> Step=1, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=2, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=3, VRAM=3.000<6.000GiB, param_mag=0.022
# >>> Step=4, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=5, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=6, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=7, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=8, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=9, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=10, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=11, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=12, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=13, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=14, VRAM=3.000<9.000GiB, param_mag=0.022
# >>> Step=15, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=16, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=17, VRAM=3.000<13.000GiB, param_mag=0.022
# >>> Step=18, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=19, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=20, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=21, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=22, VRAM=3.000<13.000GiB, param_mag=0.021
# >>> Step=23, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=24, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=25, VRAM=3.000<16.000GiB, param_mag=0.021
# >>> Step=26, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=27, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=28, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=29, VRAM=3.000<20.000GiB, param_mag=0.021
# >>> Step=30, VRAM=3.000<20.000GiB, param_mag=0.021
### Output with explode=False
# >>> Step=0, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=1, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=2, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=3, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=4, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=5, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=6, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=7, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=8, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=9, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=10, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=11, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=12, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=13, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=14, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=15, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=16, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=17, VRAM=0.000<0.000GiB, param_mag=0.022
# >>> Step=18, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=19, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=20, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=21, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=22, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=23, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=24, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=25, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=26, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=27, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=28, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=29, VRAM=0.000<0.000GiB, param_mag=0.021
# >>> Step=30, VRAM=0.000<0.000GiB, param_mag=0.021
Versions
Nightly pytorch from pip:
torch 2.6.0.dev20241028+cu121 pypi_0 pypi
torchaudio 2.5.0.dev20241028+cu121 pypi_0 pypi
torchvision 0.20.0.dev20241028+cu121 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ
vadimkantorov
Metadata
Metadata
Assignees
Labels
high priority · module: nestedtensor · NestedTensor tag, see issue #25032 · triage review · triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module