KEMBAR78
Fix reduction functions to respect the stride of the output by zou3519 · Pull Request #4995 · pytorch/pytorch · GitHub
Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Feb 1, 2018

Fixes #4974

Consider reduction ops like torch.sum, torch.prod, etc, of the form
reduce_op(output, input, keepdim)

Let output_size be the size of the output with keepdim=False, and output_keepdim_size be the size of the output with keepdim=True.

Right now, what happens in a reduce op is the following:
we have an input and an output. output is always resized to output_keepdim_size. Then, the reduction op is performed with that size, and output is finally either squeezed to output_size or kept at output_keepdim_size depending on what keepdim is.

The problem with what currently happens is that if keepdim=Falseand output initially has size output_size then output will be resized to output_keepdim_size, the reduce op will be performed, and then output will be squeezed to output_size. This resize is not a no-op and will affect output's contiguity.

This PR fixes the issue by always unsqueezing output to output_keepdim_size. This operation preserves output's contiguity.

This fixes the following operations:

  • mean
  • median
  • mode
  • norm
  • prod
  • sum
  • std
  • var
  • max
  • min

Test Plan

New unit tests

dimension + TH_INDEX_BASE);

int in_dims = THTensor_(nDimension)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);

This comment was marked as off-topic.

This comment was marked as off-topic.

@soumith soumith added the 0.3.1 label Feb 6, 2018
@soumith soumith merged commit 237c27c into pytorch:master Feb 6, 2018
soumith pushed a commit that referenced this pull request Feb 7, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.sum() ignores stride in out tensor

2 participants