Closed
Labels
module: xpu (Intel XPU related issues)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
Description
With f063027, I am trying to follow https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups to check that torch.compile works for the XPU backend. I replaced cuda with xpu throughout the example (s/cuda/xpu/); the full script is below. Running the example gives the following error:
Traceback (most recent call last):
  File "/home/dvrogozh/examples/torch/tutorials/ex4.py", line 39, in <module>
    print("eager:", timed(lambda: model(inp))[1])
  File "/home/dvrogozh/examples/torch/tutorials/ex4.py", line 13, in timed
    return result, start.elapsed_time(end) / 1000
  File "/home/dvrogozh/git/pytorch/pytorch/torch/xpu/streams.py", line 153, in elapsed_time
    return super().elapsed_time(end_event)
NotImplementedError: elapsed_time is not supported by XPUEvent.
The error ultimately comes from here:
pytorch/aten/src/ATen/xpu/XPUEvent.h
Lines 143 to 144 in 2ff98bc
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "elapsed_time is not supported by XPUEvent.");
Example script below (aka ex4.py):
import torch

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use XPU events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.xpu.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).xpu(),
        torch.randint(1000, (b,)).xpu(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).xpu()

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])
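While elapsed_time raises NotImplementedError for XPU events, one possible workaround for the `timed` helper is host-side wall-clock timing around explicit synchronization. This is only a minimal sketch and not part of the tutorial: `timed_wallclock` and its `synchronize` parameter are hypothetical names introduced here, and host timing includes launch overhead, so the numbers are not directly comparable to event-based timing.

```python
import time


def timed_wallclock(fn, synchronize=None):
    """Return (result, seconds) for one call of fn(), measured on the host.

    `synchronize` is an optional zero-argument callable (hypothetical
    parameter) that blocks until all queued device work has finished.
    """
    if synchronize is not None:
        synchronize()  # drain earlier work so it is not counted
    start = time.perf_counter()
    result = fn()
    if synchronize is not None:
        synchronize()  # wait for the work launched by fn()
    return result, time.perf_counter() - start
```

In the script above, this would presumably be called as `timed_wallclock(lambda: model(inp), torch.xpu.synchronize)` in place of `timed(lambda: model(inp))`.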
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 @vlad-penkin