KEMBAR78
Add proper shape checking to torch.cat by zou3519 · Pull Request #4087 · pytorch/pytorch · GitHub
Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Dec 8, 2017

Fixes #4071.

THTensor_(catArray) and THCTensor_(catArray) do some strange thing where the shapes of the tensors don't have to be the same in some dimensions in some cases. This adds full shape checking.

Non-empty tensor arguments to torch.cat must all have the same shape, except in the cat dimension.

Test Plan

New unit tests:

  • CPU shape check test
  • CUDA shape check test
  • CUDA torch.cat general test to match CPU's test

Asserts that the inputs have the same size except in the
cat dimension or are empty (or a mix of both).
int64_t first_dims = first->nDimension;
int64_t second_dims = second->nDimension;
THArgCheck(first_dims == second_dims, 0,
"Tensors must have same number of dimensions: got %d and %d",

This comment was marked as off-topic.

@ezyang
Copy link
Contributor

ezyang commented Dec 15, 2017

I hate to ask this now that the code is all written, but is there a compelling reason why this couldn't have been implemented in ATen? (I guess, no easy way to insert the checks into the generated code?)

@zou3519
Copy link
Contributor Author

zou3519 commented Dec 15, 2017

Good point, I hadn't thought about implementing the shape checks in aten. I think it shouldn't be too bad to do that: I could rename cat in the cwrap to something and implement a native function that wraps the renamed cat along with performing size checks.

The two downsides I see with this approach are that:

  • tensors don't benefit from the size check (only Variables, since size check is in aten). When tensors/Variables get merged this won't be a problem
  • we'd be doing an extra pass through the list of tensors to cat. Right now the size of the final tensor is computed at the same time as the size checks; separating the two would mean two loops, one for the size checks (in aten), and another to compute the size of the final tensor (in TH). This probably doesn't matter.

Alternatively I could just rewrite cat completely as a native function

@soumith
Copy link
Member

soumith commented Dec 15, 2017

we should keep them in TH/THC. there's no upside (now that the code is written) to do this in ATen.

@soumith soumith merged commit 9394e65 into pytorch:master Dec 18, 2017
@zou3519 zou3519 deleted the cat-size branch January 3, 2018 19:58
@soumith soumith added the 0.3.1 label Feb 4, 2018
soumith pushed a commit that referenced this pull request Feb 7, 2018
* Fix catArray in THTensor

Asserts that the inputs have the same size except in the
cat dimension or are empty (or a mix of both).

* Fix catArray for THCTensor

* Document torch.cat shape checks

* Fix types
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants