KEMBAR78
[MPS] Adding lgamma, digamma, and polygamma implementations by igm503 · Pull Request #106292 · pytorch/pytorch · GitHub
Skip to content

Conversation

@igm503
Copy link
Contributor

@igm503 igm503 commented Jul 31, 2023

Fixes issue mentioned in #77764

e.g. #77764 (comment)

Adds MPS support for the following ops:

  • lgamma
  • mvlgamma
  • digamma
  • polygamma

The lgamma fucntion does not yet have an MPS backend implementation. I've added one using a custom metal kernel (following John D. Cook's c++ implementation of the log gamma function: https://www.johndcook.com/blog/cpp_gamma/). For the backward pass op, I've added a digamma kernel that follows the cpu+cuda digamma implementation, and for the backward pass of the digamma op, I've added a polygamma + trigamma kernel following, again, the cpu+cuda implementations.

NOTE:

The cpu implementation of the polygamma function incorrectly (as far as I can tell) outputs a finite number for order = 1 and x in the negative integers. The mps implementation correctly outputs infinite. (see #106692)

The polygamma tests currently don't pass because of the error in the cpu+cuda kernels, but also because there are smallish discrepancies near the negative integers between the cpu+cuda and the mps polygamma and trigamma kernels. I'm not sure exactly why this is, but let me know if the discrepancies are too big.

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106292

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit cc34d98 with merge base 703cdd7 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Jul 31, 2023
@igm503 igm503 marked this pull request as ready for review August 7, 2023 06:29
@igm503 igm503 requested a review from kulinseth as a code owner August 7, 2023 06:29
@igm503 igm503 changed the title added log-gamma function kernel for mps backend [MPS] lgamma, digamma, and polygamma implementations Aug 7, 2023
@igm503 igm503 changed the title [MPS] lgamma, digamma, and polygamma implementations [MPS] Adding lgamma, digamma, and polygamma implementations Aug 7, 2023
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 7, 2023
@igm503
Copy link
Contributor Author

igm503 commented Aug 26, 2023

@kulinseth Any chance you can give this a look and advise about whether the test failures are a problem?

@kulinseth
Copy link
Collaborator

=================================== FAILURES ===================================
______________ TestFallbackWarning.test_error_on_not_implemented _______________
Traceback (most recent call last):
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_mps.py", line 10438, in test_error_on_not_implemented
    fn(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_5983090866/lib/python3.9/unittest/case.py", line 226, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_5983090866/lib/python3.9/unittest/case.py", line 163, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: NotImplementedError not raised

This issue seems unrelated to the PR. Can you @igm503 please rebase the PR?

@igm503 igm503 force-pushed the lgamma branch 2 times, most recently from 048b88a to 3938014 Compare September 2, 2023 20:04
@igm503 igm503 closed this Sep 2, 2023
@igm503 igm503 deleted the lgamma branch September 2, 2023 20:08
@igm503 igm503 restored the lgamma branch September 2, 2023 20:14
@igm503 igm503 reopened this Sep 2, 2023
@kulinseth
Copy link
Collaborator

@kulinseth Any chance you can give this a look and advise about whether the test failures are a problem?

@igm503 the assertion is coming from not implemented test . Can you check if lgamma tests are not in that category class in test_mps .

@igm503
Copy link
Contributor Author

igm503 commented Sep 3, 2023

@kulinseth I've fixed the assertion error by swapping another not-yet-implemented op for lgamma in the not_implemented test.

@igm503
Copy link
Contributor Author

igm503 commented Sep 6, 2023

@kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.

@kulinseth
Copy link
Collaborator

@kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.

@igm503 , I see, we can add these tests to XFAILLIST here.

@igm503
Copy link
Contributor Author

igm503 commented Sep 7, 2023

@kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.

@igm503 , I see, we can add these tests to XFAILLIST here.

The tests now pass on the macos 13 builds.

@kulinseth However, since there are precision issues with test_output_grad_match_polygamma_polygamma_n_0_cpu_float32 on macos 12 as well, where should I put that exception? I scanned the different XFAILLISTs, and I don't see a clear place for it. Of course, I could put it in the pre-13 XFAIL list, but that would make it seem like it's fixed for >13, which it isn't.

@igm503
Copy link
Contributor Author

igm503 commented Sep 11, 2023

@kulinseth I went ahead and added the failing tests to the MACOS_BEFORE_13_3_XFAILLIST as well. Let me know if there's a more appropriate place to put them.

Copy link
Collaborator

@kulinseth kulinseth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good

@igm503
Copy link
Contributor Author

igm503 commented Sep 12, 2023

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2023
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64 / test (default, 3, 3, macos-m1-12)

Details for Dev Infra team Raised by workflow job

@igm503
Copy link
Contributor Author

igm503 commented Sep 12, 2023

@pytorchbot merge -i

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: mps Release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants