Fix underflow issue with dirichlet sample by neerajprad · Pull Request #17488 · pytorch/pytorch · GitHub

Conversation

@neerajprad
Contributor

@neerajprad neerajprad commented Feb 26, 2019

Addresses #15738, using @fritzo's suggestion. This adds a `torch._sample_dirichlet` method in `Distributions.cpp` and `Distributions.cu`.

  • For CPU, this leads to no perf hit, since all we do is promote `alpha` to double when getting the gamma samples (the gamma sampler uses `accscalar_t`, i.e. double for CPU, anyway) and cast the result back to float32 on return.
  • I have added an analogous method for CUDA as well, but the default sampler for CUDA uses `scalar_t` for efficiency, so I have kept it that way. With this, I no longer see the bias towards 1 reported in #15738 ("Beta Distribution values wrong for a=b → 0") with `float32`, but there is a spurious mode at 0.5, as would be expected. Users would need to explicitly use `float64` on the GPU to avoid the spurious mode at 0.5. (EDIT: see the note below; it appears that the bias issue is still present for certain builds.)

Added some tests and checked that there is no perf regression. My experience with C++ is very limited, so apologies in advance if I missed something basic. cc. @ailzhang, @fritzo, @fmassa
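
A minimal Python sketch of the idea (illustrative only, not the ATen code; the sample count is arbitrary): draw the gamma variates that underlie the Beta/Dirichlet sample in float64 so they are far less likely to underflow to 0, normalize, and only then cast back to float32.

import torch

# Beta(0.01, 0.01) built from a pair of Gamma(0.01, 1) draws, computed in float64.
conc = torch.full((2, 100000), 0.01, dtype=torch.float64)
g = torch.distributions.Gamma(conc, torch.ones_like(conc)).sample()
beta = (g[0] / g.sum(0)).to(torch.float32)    # normalize in double, then downcast
print((beta == 1.0).float().mean().item())    # no heavy pile-up at exactly 1 as reported in #15738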

  BaseSampler<accscalar_t, decltype(normal_lambda)> standard_normal(normal_lambda);
  auto sample = sample_gamma<scalar_t, accscalar_t, decltype(uniform_lambda), decltype(normal_lambda)>(alpha, standard_uniform, standard_normal);
- auto min_value = std::numeric_limits<scalar_t>::lowest();
+ auto min_value = std::numeric_limits<scalar_t>::min();
Contributor Author

I think this was a mistake, since `lowest()` would be a (large) negative value. I was seeing 0s on CUDA.
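
For reference, the two limits being mixed up here, shown with NumPy's float32 equivalents (an illustration, not code from this PR):

import numpy as np

fi = np.finfo(np.float32)
print(fi.min)    # -3.4028235e+38: most negative float32, i.e. std::numeric_limits<float>::lowest()
print(fi.tiny)   #  1.1754944e-38: smallest positive normal float32, i.e. std::numeric_limits<float>::min()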

@ezyang
Contributor

ezyang commented Feb 26, 2019

Looks like the test failure wasn't totally fixed:

Feb 26 08:05:58 ======================================================================
Feb 26 08:05:58 FAIL: test_beta_wrt_beta (__main__.TestRsample)
Feb 26 08:05:58 ----------------------------------------------------------------------
Feb 26 08:05:58 Traceback (most recent call last):
Feb 26 08:05:58   File "test_distributions.py", line 2625, in test_beta_wrt_beta
Feb 26 08:05:58     'at x = {!r}'.format(x[rel_error.argmax()]),
Feb 26 08:05:58 AssertionError: 4.3483026956891986e+23 not less than 0.005 : Bad gradient dx/dcon0 for x ~ Beta(0.01, 0.01)
Feb 26 08:05:58 x [1.1754944e-38 1.1754944e-38 1.1754944e-38 1.1754944e-38 1.1754944e-38
Feb 26 08:05:58  1.1754944e-38 3.0752082e-31 1.7097079e-26 8.5553456e-06 7.9280101e-03
Feb 26 08:05:58  1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00
Feb 26 08:05:58  1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00]
Feb 26 08:05:58 expected [-5.87948525e-35 -5.87948525e-35 -5.87948525e-35 -5.87948525e-35
Feb 26 08:05:58  -5.87948525e-35 -5.87948525e-35 -1.53813084e-27 -8.55146827e-23
Feb 26 08:05:58  -4.27910236e-02 -3.93453946e+01 -0.00000000e+00 -0.00000000e+00
Feb 26 08:05:58  -0.00000000e+00 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00
Feb 26 08:05:58  -0.00000000e+00 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00]
Feb 26 08:05:58 actual [-5.8793647e-35 -5.8793647e-35 -5.8793647e-35 -5.8793647e-35
Feb 26 08:05:58  -5.8793647e-35 -5.8793647e-35 -1.5380993e-27 -8.5512932e-23
Feb 26 08:05:58  -4.2790145e-02 -3.9344589e+01 -4.0541446e-34 -1.9881570e-08
Feb 26 08:05:58  -3.5506812e-34 -1.6141941e-34 -4.3483027e-07 -1.6141941e-34
Feb 26 08:05:58  -1.6141941e-34 -6.0947656e-22 -1.6141941e-34 -3.9208717e-32]
Feb 26 08:05:58 rel error [ 1.20539195e-09  1.20539195e-09  1.20539195e-09  1.20539195e-09
Feb 26 08:05:58   1.20539195e-09  1.20539195e-09 -2.05402476e-05 -2.04689217e-05
Feb 26 08:05:58  -2.05395321e-05 -2.04679638e-05  4.05414456e-04  1.98815702e+22
Feb 26 08:05:58   3.55068125e-04  1.61419412e-04  4.34830270e+23  1.61419412e-04
Feb 26 08:05:58   1.61419412e-04  6.09476561e+08  1.61419412e-04  3.92087169e-02]
Feb 26 08:05:58 max error 4.3483026956891986e+23
Feb 26 08:05:58 at x = 1.0
Feb 26 08:05:58 
Feb 26 08:05:58 ----------------------------------------------------------------------
Feb 26 08:05:58 Ran 192 tests in 24.281s

@neerajprad
Contributor Author

Looks like the test failure wasn't totally fixed

Thanks for pointing this out. I couldn't reproduce this on Mac, but I am getting this error on my Linux workstation. Will debug and push a fix.

@neerajprad
Contributor Author

neerajprad commented Feb 27, 2019

That was a legitimate error: `1.0f - std::numeric_limits<float>::min()` is still `== 1.0f`, not the largest number representable in float that is less than 1.0f. I am currently doing `1 - std::numeric_limits<float>::epsilon()` on CUDA due to the lack of `std::nexttoward`.
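
The same behaviour, illustrated with NumPy's float32 rather than the CUDA code (illustration only):

import numpy as np

one = np.float32(1.0)
print(one - np.finfo(np.float32).tiny == one)   # True: min() (~1.2e-38) is absorbed, the result is still 1.0f
print(one - np.finfo(np.float32).eps)           # 0.99999988, strictly below 1.0 (the 1 - epsilon() clamp)
print(np.nextafter(one, np.float32(0.0)))       # 0.99999994, the largest float32 below 1.0 (what nexttoward would give)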

@fritzo
Collaborator

fritzo commented Feb 27, 2019

looks like the latest test failures are real:

02:47:15 ======================================================================
02:47:15 FAIL: test_beta_underflow_gpu (test_distributions.TestDistributions)
02:47:15 ----------------------------------------------------------------------
02:47:15 Traceback (most recent call last):
02:47:15   File "/var/lib/jenkins/workspace/test/common_utils.py", line 296, in wrapper
02:47:15     method(*args, **kwargs)
02:47:15   File "/var/lib/jenkins/workspace/test/test_distributions.py", line 2240, in test_beta_underflow_gpu
02:47:15     self.assertEqual(frac_zeros, 0.5, 0.05)
02:47:15   File "/var/lib/jenkins/workspace/test/common_utils.py", line 453, in assertEqual
02:47:15     super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
02:47:15 AssertionError: 0.10202 not less than or equal to 0.05 : 
02:47:15 
02:47:15 ----------------------------------------------------------------------

@neerajprad
Contributor Author

neerajprad commented Feb 28, 2019

looks like the latest test failures are real

It seems like we might have a bias between the number of samples at 0 and at 1 for Beta(0.01, 0.01) on the GPU for these builds. While the split should be 50:50, it is 60:40 or 40:60 (still much less severe than what was reported in #15738). I don't see this imbalance locally on Ubuntu 16.04 with CUDA 10, though.

I have relaxed the test tolerance, but just noting that the GPU imbalance reported in #15738 isn't fully fixed by this PR and still exists for certain builds. I'll need more help debugging that, because I can't seem to reproduce it locally.
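
For context, roughly what test_beta_underflow_gpu checks (a sketch; the sample count and tolerance here are illustrative, not copied from the test, and a CUDA device is required):

import torch

conc = torch.full((50000,), 0.01, device="cuda")
x = torch.distributions.Beta(conc, conc).sample()   # with such small concentrations, most float32 samples collapse to exactly 0 or 1
frac_zeros = (x == 0).float().mean().item()
print(frac_zeros)   # should be close to 0.5 by symmetry; around 0.4 or 0.6 on the affected builds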

@neerajprad
Contributor Author

The recent failures are unrelated to this change; they are build issues.

@ezyang
Contributor

ezyang commented Mar 6, 2019

@fritzo I'll merge this on your say so.

Collaborator

@fritzo fritzo left a comment

LGTM. Thanks for fixing this!

@ezyang I am happy with the tests and math, but my C++ and CUDA are rusty. I see no errors, but I defer to you for C++/CUDA nitpicking.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor

ezyang commented Mar 12, 2019

@li-roy We can't land this patch until BC is restored:

caffe2/aten/src/ATen/native/Distributions.cpp:233:3: error: statement requires expression of integer type ('at::Type' invalid)
  AT_DISPATCH_FLOATING_TYPES(ret.type(), "dirichlet", [&] {
  ^                          ~~~~~~~~~~
caffe2/aten/src/ATen/native/Distributions.cpp:233:3: error: no matching function for call to 'toString'
  AT_DISPATCH_FLOATING_TYPES(ret.type(), "dirichlet", [&] {
  ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@neerajprad You don't need to do anything, but this diff no longer applies on master due to a BC-breaking change to AT_DISPATCH_FLOATING_TYPES

@neerajprad
Contributor Author

@neerajprad You don't need to do anything, but this diff no longer applies on master due to a BC-breaking change to AT_DISPATCH_FLOATING_TYPES

I see that this was recently changed in #17527. Should I push a fix for this?

@li-roy
Contributor

li-roy commented Mar 13, 2019

@ezyang In the meantime, this can just be changed to `ret.scalar_type()`, I think.

@ezyang
Contributor

ezyang commented Mar 18, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor

ezyang commented Mar 19, 2019

Good to go.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Mar 19, 2019
Pull Request resolved: pytorch/pytorch#17488

Differential Revision: D14410301

Pulled By: ezyang

fbshipit-source-id: 62b2f694b4642685eab06db96d74ce28e05c3992