OpInfo for nn.functional.softmax by krshrimali · Pull Request #62077 · pytorch/pytorch · GitHub
Conversation

@krshrimali
Contributor

@krshrimali krshrimali commented Jul 23, 2021

This PR:

  • Adds OpInfo for softmax and nn.functional.softmax (alias).
  • Removes the skip for the test_jit_alias_remapping test of log_softmax.

Please see pytorch/functorch#78 and #54261.
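
To make the idea concrete, here is a minimal sketch of what an OpInfo-style entry with an alias captures. `ToyOpInfo`, its fields, and the `registry` dict are hypothetical stand-ins written for illustration, not PyTorch's actual `OpInfo` class, and the softmax is a plain-Python version over lists rather than tensors.

```python
import math
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple


def softmax(xs: Sequence[float]) -> list:
    """Numerically stable softmax over a flat sequence."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


@dataclass(frozen=True)
class ToyOpInfo:
    """Hypothetical stand-in for an operator test record."""
    name: str
    op: Callable
    aliases: Tuple[str, ...] = ()


# One record describes the operator; the alias names point at the same callable.
softmax_info = ToyOpInfo(name="softmax", op=softmax, aliases=("nn.functional.softmax",))

# A toy namespace that resolves both the canonical name and its aliases.
registry = {softmax_info.name: softmax_info.op}
registry.update({alias: softmax_info.op for alias in softmax_info.aliases})
```

With a record like this, an alias test only has to confirm that both names resolve to the same behavior.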

cc: @mruberry @zou3519 @pmeier

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 23, 2021

💊 CI failures summary and remediations

As of commit bd9a1d3 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


@krshrimali krshrimali requested a review from mruberry July 23, 2021 05:25
@krshrimali krshrimali added the module: testing label Jul 23, 2021
@jbschlosser jbschlosser added the triaged label Jul 23, 2021
@mruberry mruberry requested a review from zou3519 July 24, 2021 10:20
@mruberry
Collaborator

This looks pretty good @krshrimali; I have one suggestion (inline) for how to tweak testing the "dtype" kwarg.

@krshrimali krshrimali mentioned this pull request Jul 24, 2021
Contributor Author

@krshrimali krshrimali left a comment

A few updates:

  • This PR now reuses the sample inputs function of log_softmax for softmax.
  • Added an alias for softmax.
  • Cleaned up some OpInfos that explicitly passed parameters with their default values (which isn't needed).
  • Removed the skip for the test_jit_alias_remapping test of log_softmax.

cc: @mruberry @zou3519 (sorry for the ping over the weekend, please review whenever you find time)

@krshrimali
Contributor Author

krshrimali commented Jul 26, 2021

Update:

I'm taking a look at the XLA error. If it turns out to be non-trivial, I'll probably remove the scalar input and file an issue about the error. Also, this is only reproducible on XLA (tested on Google Colab) with scalar tensors:

# Works on CPU
>>> torch.log_softmax(torch.rand((), device='cpu'), dim=0)
tensor(0.)

# Fails on XLA
>>> torch.log_softmax(torch.rand((), device='xla'), dim=0)
ERROR (please see the error below since it's too verbose)

Error:

Error on XLA

RuntimeError                              Traceback (most recent call last)
<ipython-input-4-470607feefbc> in <module>()
----> 1 torch.log_softmax(torch.rand((), device='xla'), dim=0)

RuntimeError: torch_xla/csrc/helpers.cpp:97 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace()
	torch_xla::XlaHelpers::GetCanonicalDimensionIndex(long, long)
	torch_xla::XLATensor::log_softmax(torch_xla::XLATensor const&, long, c10::optional<c10::ScalarType>)
	torch_xla::AtenXlaType::_log_softmax(at::Tensor const&, long, bool)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, bool> >, at::Tensor (at::Tensor const&, long, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, bool)
	at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, long, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, long, bool)> const&, c10::DispatchKeySet, at::Tensor const&, long, bool) const
	at::redispatch::_log_softmax(c10::DispatchKeySet, at::Tensor const&, long, bool)
	
	
	at::_log_softmax(at::Tensor const&, long, bool)
	at::native::log_softmax(at::Tensor const&, long, c10::optional<c10::ScalarType>)
	
	at::Tensor::log_softmax(long, c10::optional<c10::ScalarType>) const
	
	_PyMethodDef_RawFastCallKeywords
	_PyCFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCode
	
	_PyMethodDef_RawFastCallKeywords
	_PyCFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyObject_Call_Prepend
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallDict
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallDict
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyObject_Call_Prepend
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyObject_Call_Prepend
	_PyObject_FastCallKeywords
	
	_PyMethodDef_RawFastCallDict
	PyCFunction_Call
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCode
	
	_PyMethodDef_RawFastCallKeywords
	_PyCFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallDict
	
	
	_Py_UnixMain
	__libc_start_main
	_start
*** End stack trace ***
Value out of range (expected to be in range of [0, -1], but got 0)

Also, the error message doesn't look right: it reports the valid range as [0, -1], which is empty, so even dim=0 is rejected.
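
The odd range in that message can be reproduced with a small sketch. The two helpers below are hypothetical simplifications written for illustration, not the actual PyTorch or torch_xla code: one bounds-checks `dim` against the tensor rank directly, and the other first treats a 0-d tensor as if it had rank 1. With rank 0, the strict version computes the range [0, -1], which contains no values.

```python
def canonical_dim_strict(dim: int, rank: int) -> int:
    """Bounds-check dim against [-rank, rank - 1] with no special case for scalars."""
    lo, hi = -rank, rank - 1
    if not (lo <= dim <= hi):
        raise IndexError(
            f"Value out of range (expected to be in range of [{lo}, {hi}], but got {dim})"
        )
    # Map negative dims onto their non-negative equivalent.
    return dim + rank if dim < 0 else dim


def canonical_dim_wrap_scalar(dim: int, rank: int) -> int:
    """Same check, but a 0-d tensor is treated as rank 1, so dim=0 and dim=-1 pass."""
    return canonical_dim_strict(dim, max(rank, 1))
```

Under these assumptions, `canonical_dim_strict(0, 0)` raises with the same `[0, -1]` wording seen in the trace, while `canonical_dim_wrap_scalar(0, 0)` accepts the call.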

Will update once I've finished building XLA locally to test the macros used. Thanks!

@krshrimali
Contributor Author

krshrimali commented Jul 27, 2021

Updates:

On a 0-d tensor with dim=0, PyTorch on CPU doesn't throw an error, but XLA does. This needs some discussion, but I didn't want it to block this PR, so I removed the ((), 0) case (input shape, dim).
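
For reference, the `tensor(0.)` that the CPU path returns for a 0-d input is what the math predicts: softmax over a single element is 1, so its log is 0. A minimal plain-Python sketch (the `log_softmax` helper here is written for illustration and operates on a list, not a tensor):

```python
import math
from typing import Sequence


def log_softmax(xs: Sequence[float]) -> list:
    """log(softmax(xs)) computed with the log-sum-exp trick for stability."""
    m = max(xs)
    log_sum = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_sum for x in xs]
```

A single-element input always yields zero here, regardless of the element's value, which matches the CPU result quoted earlier.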

This PR should be ready for review, hopefully, tests should pass. Thanks!

cc: @mruberry @zou3519 @pmeier

Collaborator

@pmeier pmeier left a comment

Two nits inline, otherwise LGTM! Thanks @krshrimali.

@krshrimali
Contributor Author

Gentle ping @zou3519 @mruberry - if you can take a look whenever you find the time. Thanks! :)

def generator():
    for shape, args, kwargs in cases:
        yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)
cases = (
Contributor

This is a pathological case, but could we add ((), (0,)) to the list?

Contributor Author

Thanks for the comment @zou3519! I had added this before but the test fails on the XLA device: #62077 (comment). PyTorch on 0d tensors with dim=0 doesn't throw an error but on XLA, it does.

Contributor

Gotcha, sorry for not seeing that! The action items you proposed (file an issue, leave that case out of the OpInfo) sgtm

Contributor Author

Thanks, @zou3519! After discussing with @mruberry, we decided to add a separate case for when the device type isn't XLA, so that we don't skip this input for CPU/CUDA devices. I've also filed an issue here: pytorch/xla#3061.
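
The approach described in this comment could look roughly like the following sketch. Everything here is a simplified assumption for illustration: `SampleInput` is a toy dataclass rather than the class in PyTorch's internal testing module, shapes stand in for tensors, the listed shapes are made up, and the device check is a plain string comparison.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class SampleInput:
    """Toy stand-in: holds a shape instead of a real tensor."""
    shape: Tuple[int, ...]
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)


def sample_inputs_softmax_sketch(device: str) -> List[SampleInput]:
    # (shape, (dim,)) pairs that every backend is expected to handle.
    cases = [
        ((4,), (0,)),
        ((3, 5), (-1,)),
        ((3, 5), (1,)),
    ]
    # The 0-d input with dim=0 works on CPU/CUDA but raises on XLA
    # (see pytorch/xla#3061), so it is only added for non-XLA devices.
    if not device.startswith("xla"):
        cases.append(((), (0,)))
    return [SampleInput(shape, args=dim) for shape, dim in cases]
```

Keeping the exclusion inside the sample-inputs function means CPU and CUDA still exercise the scalar case, and the special case can be deleted once the XLA issue is resolved.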

Contributor

@zou3519 zou3519 left a comment

This looks pretty good. I added some comments about some more cases for completeness, after that we should be good to go

@zou3519
Contributor

zou3519 commented Jul 29, 2021

Lint is failing: https://github.com/pytorch/pytorch/pull/62077/checks?check_run_id=3191463113

Run (! git --no-pager grep -In '[[:blank:]]$' -- . ':(exclude)**/contrib/**' ':(exclude)**.diff' ':(exclude)third_party' || (echo "The above lines have trailing spaces; please remove them"; false))
torch/testing/_internal/common_methods_invocations.py:7922:        sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), 
The above lines have trailing spaces; please remove them
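
A check like this can also be approximated locally before pushing. The sketch below is a plain-Python approximation of what the `grep` pattern `[[:blank:]]$` matches (a space or tab at the end of a line). The function name is made up for illustration, and it inspects a string rather than the repository.

```python
import re
from typing import List, Tuple

# A space or tab immediately before the end of the line.
_TRAILING_BLANK = re.compile(r"[ \t]$")


def find_trailing_blanks(text: str) -> List[Tuple[int, str]]:
    """Return (1-based line number, line) for every line ending in a space or tab."""
    return [
        (number, line)
        for number, line in enumerate(text.split("\n"), start=1)
        if _TRAILING_BLANK.search(line)
    ]
```

Running it over the contents of a changed file gives the same kind of line-numbered report as the lint job.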

@krshrimali
Contributor Author

Thanks, @zou3519 for the pointer, I've fixed it now. Hopefully, the tests should pass :)

@facebook-github-bot
Contributor

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

for shape, args, kwargs in cases:
    yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)
# PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
# See https://github.com/pytorch/xla/issues/3061 for more details.
Collaborator

Nice comment

return list(generator())
return [
    SampleInput(make_arg(shape), args=dim, kwargs=dict(dtype=torch.float64) if with_dtype else None)
    for shape, dim in cases
Collaborator

Style nit: put the for loop first for readability (doesn't have to be changed in this PR)
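
One way to read this suggestion is to swap the comprehension for an explicit loop, so the iteration is visible before the long constructor call. A sketch under that reading, using a hypothetical `make_sample` helper in place of `SampleInput(make_arg(...))` and a string in place of `torch.float64`:

```python
def make_sample(shape, dim, dtype=None):
    """Hypothetical stand-in for building a SampleInput; returns a plain dict."""
    return {"shape": shape, "args": dim, "kwargs": {"dtype": dtype} if dtype else {}}


def build_samples(cases, with_dtype=False):
    dtype = "float64" if with_dtype else None
    samples = []
    # The loop header comes first, so readers see what is iterated
    # before they reach the long constructor call.
    for shape, dim in cases:
        samples.append(make_sample(shape, dim, dtype=dtype))
    return samples
```

Both forms build the same list; the difference is only in reading order.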

def sample_inputs_log_softmax(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
# Used for both log_softmax and softmax
def sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
Collaborator

"with_dtype" should be kwarg-only

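In Python, a bare `*` in the signature makes every named parameter after it keyword-only. A minimal sketch of that change, with a simplified body that just echoes its arguments (the function name and return value are illustrative, not the PR's final code):

```python
def sample_inputs_softmax_variant_sketch(op_info, device, dtype, requires_grad, *, with_dtype=False, **kwargs):
    """Echo the flag so the calling convention can be checked."""
    return {"device": device, "dtype": dtype, "with_dtype": with_dtype}
```

Callers that already bind the flag by name, such as `partial(sample_inputs_softmax_variant, with_dtype=True)`, keep working; only a fifth positional argument would now be rejected with a `TypeError`.
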
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
supports_forward_ad=True,
sample_inputs_func=sample_inputs_max_min_binary,),
# `softmax` supports different dtypes based on whether `dtype` argument,
Collaborator

Nice comment

@mruberry mruberry self-requested a review July 30, 2021 04:43
Collaborator

@mruberry mruberry left a comment

Really nice skip removal. Overall looks good. I made a few comments inline to consider in future PRs, no changes needed for this one

@facebook-github-bot
Contributor

@zou3519 merged this pull request in 09d10c4.

Labels

cla signed, Merged, module: testing, open source, triaged

7 participants