Status: Closed
Labels: internal ramp-up task (suitable for new folks with high-touch guidance from senior PyTorch folks), module: inductor, oncall: pt2, triaged (looked at by a team member and prioritized into an appropriate module)
Description
🚀 The feature, motivation and pitch
Currently, Inductor's inplacing logic only applies to pointwise uses of a buffer. We should extend it to reduction uses as well.
For example, in the snippet below, the buffer holding matmul_output could be reused in place for the output of the LayerNorm that consumes it:
import torch

torch.set_default_device("cuda")
torch.set_grad_enabled(False)

batch_size = 32
seq_length = 50
hidden_size = 768

inp = torch.randn(batch_size, seq_length, hidden_size)
weight = torch.randn(hidden_size, hidden_size)
layer_norm = torch.nn.LayerNorm(hidden_size)

@torch.compile()
def foo(inp, weight):
    matmul_output = inp @ weight
    final_output = layer_norm(matmul_output)
    return final_output

foo(inp, weight)
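For illustration, here is a hand-written analogue of the transformation Inductor could perform. This is a sketch of the idea, not Inductor's actual codegen; foo_manually_inplaced and the ln_weight / ln_bias parameters are hypothetical names introduced here:

def foo_manually_inplaced(inp, weight, ln_weight, ln_bias, eps=1e-5):
    matmul_output = inp @ weight
    # The LayerNorm reductions read matmul_output...
    mean = matmul_output.mean(dim=-1, keepdim=True)
    var = matmul_output.var(dim=-1, unbiased=False, keepdim=True)
    # ...after which the buffer has no other users, so the normalized
    # result can be written back into the same storage.
    matmul_output.sub_(mean).div_((var + eps).sqrt_())
    matmul_output.mul_(ln_weight).add_(ln_bias)
    return matmul_output  # same storage as the matmul result

The reductions (mean, var) consume matmul_output, and because nothing else reads it afterwards, the pointwise epilogue can write straight back into the same buffer; today Inductor only recognizes this pattern when the consuming op is pointwise.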
This can help with both performance and memory, since the LayerNorm output no longer needs a separate allocation.
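One rough way to observe the memory side is to compare peak allocated memory before and after the change. The snippet below is a measurement sketch, not part of the original report; it reuses inp, weight, and foo from above and assumes a CUDA device:

torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
foo(inp, weight)
torch.cuda.synchronize()
peak_mb = torch.cuda.max_memory_allocated() / 1e6
print(f"peak allocated: {peak_mb:.1f} MB")
# With inplacing, the peak should drop by roughly one
# batch_size * seq_length * hidden_size float32 buffer (~4.9 MB here).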
Alternatives
No response
Additional context
No response
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire