[Inductor] Generalize inductor triton backend device agnostic by guangyey · Pull Request #109486 · pytorch/pytorch · GitHub

Conversation

@guangyey
Copy link
Collaborator

@guangyey guangyey commented Sep 18, 2023

Motivation

@jansel As discussed before, we would like to generalize some CUDA-specific code. This makes Inductor friendlier to third-party backends, so that they can leverage the Inductor code base as much as possible.

Solution

To implement this, we introduce a device runtime abstraction. We wrap the runtime APIs inside DeviceInterface and use register_interface_for_device to register each kind of device with Inductor, then use get_interface_for_device to fetch the corresponding runtime for a device type. Usage looks like this:

device_interface = get_interface_for_device("xpu")
device_interface.is_available()   # check whether XPU is available
device_interface.device_count()   # check how many XPU devices are available

DeviceInterface is a simple abstraction that enables third-party backends implementing CUDA-like semantics to integrate with Inductor. It also keeps third-party backends from monkey-patching utility functions, like decode_device, that are hard-coded for CUDA.
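The registry pattern described above can be sketched as follows. The names DeviceInterface, register_interface_for_device, and get_interface_for_device mirror the PR; the bodies below are illustrative stubs, not the actual PyTorch implementation.

```python
# Minimal sketch of the device-interface registry described in this PR.
# XpuInterface and its return values are illustrative stubs.
from typing import Dict, Type


class DeviceInterface:
    """Abstract per-device runtime interface (CUDA-like semantics)."""

    @staticmethod
    def is_available() -> bool:
        raise NotImplementedError

    @staticmethod
    def device_count() -> int:
        raise NotImplementedError


_device_interfaces: Dict[str, Type[DeviceInterface]] = {}


def register_interface_for_device(device: str, iface: Type[DeviceInterface]) -> None:
    # Each backend registers its runtime implementation under a device type.
    _device_interfaces[device] = iface


def get_interface_for_device(device: str) -> Type[DeviceInterface]:
    # Inductor code looks up the runtime by device type instead of
    # calling torch.cuda.* directly.
    if device not in _device_interfaces:
        raise NotImplementedError(f"No interface registered for {device}")
    return _device_interfaces[device]


# A third-party backend plugs in its own implementation:
class XpuInterface(DeviceInterface):
    @staticmethod
    def is_available() -> bool:
        return True  # stub

    @staticmethod
    def device_count() -> int:
        return 2  # stub


register_interface_for_device("xpu", XpuInterface)
device_interface = get_interface_for_device("xpu")
print(device_interface.is_available(), device_interface.device_count())  # True 2
```

Because the registry maps a device-type string to a class, Inductor never needs a compile-time dependency on any particular backend module.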

Additional Context

The main code change:

  • Make AsyncCompile device-agnostic so that other backends can leverage it
  • Avoid monkey patches by making some utility functions device-agnostic
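The second bullet can be sketched in miniature. Instead of a helper such as decode_device defaulting to a hard-coded "cuda" (which forced out-of-tree backends to monkey-patch it), the default device type is looked up from state that a backend can register. The Device dataclass and register_gpu_type here are illustrative stand-ins, not the actual Inductor code.

```python
# Hedged sketch: a device-agnostic decode_device. A backend registers its
# device type once, rather than patching the helper. All names are stubs.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Device:
    type: str
    index: Optional[int] = None


_registered_gpu_type = "cuda"  # the old hard-coded default


def register_gpu_type(device_type: str) -> None:
    # A third-party backend opts in instead of monkey-patching decode_device.
    global _registered_gpu_type
    _registered_gpu_type = device_type


def decode_device(device: Union[Device, str, None]) -> Device:
    """Normalize a device argument, defaulting to the registered GPU type."""
    if device is None:
        return Device(_registered_gpu_type, 0)
    if isinstance(device, str):
        device = Device(device)
    if device.type != "cpu" and device.index is None:
        return Device(device.type, 0)
    return device


register_gpu_type("xpu")
print(decode_device(None))   # Device(type='xpu', index=0)
print(decode_device("cpu"))  # Device(type='cpu', index=None)
```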

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @gujinghui @arthuryuan1987

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 18, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109486

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f3427f2 with merge base dee1009:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fengyuan14
Copy link
Collaborator

Hi @jansel,
This PR unifies runtime code in the Inductor scheduling and Triton codegen components by providing a general runtime interface, so that non-CUDA backends can share this code.
Please review. Thanks.

Comment on lines 11 to 12
_compile_worker_device_properties: Dict[str, Any] = {}
_compile_worker_current_devices: Dict[str, int] = {}
Copy link
Contributor


This seems pretty similar to RuntimeInterface in runtime.py above. Can we just merge both files into a single subclass?

Copy link
Collaborator Author


I gave it a try, but ran into an issue. In my understanding, DeviceInterface is an interface to the device runtime, so APIs like set_device and get_device_properties should be general. By contrast, set_compiler_worker_current_device, current_device, and get_device_properties in device_properties.py are only used when compiling Triton kernels with multi-processing.
In fact, get_device_properties in device_properties.py has a potential issue if the user calls set_device instead of set_compiler_worker_current_device to switch the device index: current_device in device_properties.py always fetches the cached value from _compile_worker_current_devices instead of calling torch.cuda.current_device() at runtime.
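The stale-cache hazard described above can be shown with stubs. The dictionaries below stand in for _compile_worker_current_devices and the real runtime (torch.cuda.current_device()); the function bodies are illustrative, not the actual device_properties.py code.

```python
# Sketch of the hazard: current_device() prefers a cache that only the
# compile-worker setter updates, so a plain set_device() call goes unseen.
_compile_worker_current_devices = {}
_runtime_current = {"cuda": 0}  # stand-in for torch.cuda.current_device()


def set_compiler_worker_current_device(device_type: str, index: int) -> None:
    _compile_worker_current_devices[device_type] = index


def set_device(device_type: str, index: int) -> None:
    # Switches the runtime device, but the worker cache is NOT updated.
    _runtime_current[device_type] = index


def current_device(device_type: str) -> int:
    # Mirrors device_properties.py: the cache always wins when present.
    if device_type in _compile_worker_current_devices:
        return _compile_worker_current_devices[device_type]
    return _runtime_current[device_type]


set_compiler_worker_current_device("cuda", 0)
set_device("cuda", 1)          # user switches via the regular runtime API
print(current_device("cuda"))  # 0 -- the cached value is stale
```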

Copy link
Collaborator Author


So I would like to know whether it is reasonable to keep device_properties.py. If you prefer to merge them into DeviceInterface, is it OK to add two specific methods, `set_device_for_cache` and `get_device_properties_for_cache`, to `DeviceInterface` to distinguish them from `set_device` and `get_device_properties`?

Copy link
Contributor


I think it is better to merge them unless there is some technical reason not to.

Copy link
Collaborator Author


done.

@guangyey
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 21, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Copy link
Collaborator

@guangyey
Copy link
Collaborator Author

@EikanWang
Copy link
Collaborator

@guangyey , please check the failed cases locally. It might be caused by this PR.

@EikanWang EikanWang added the intel This tag is for PR from Intel label Sep 21, 2023
@guangyey
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here


Labels

ciflow/inductor, ciflow/trunk, intel, Merged, module: dynamo, module: inductor, open source, release notes: foreach_frontend
