Support multivariate TransformedDistributions by fritzo · Pull Request #4937 · pytorch/pytorch

Conversation


fritzo (Collaborator) commented Jan 30, 2018

Reviewed by @rachtsingh and @alicanb at probtorch#116

This adds an .event_dim attribute to all Transforms and correctly handles event shape in TransformedDistribution.log_prob() and ComposeTransform.log_abs_det_jacobian(). Cases we need to handle are:

  • When TransformedDistribution.base_dist has a larger event_dim than its transforms, we need to sum out the rightmost dimensions of each transform.log_abs_det_jacobian() term; otherwise there will be a shape error (see the sketch after this list).
  • When TransformedDistribution.base_dist has a smaller event_dim than its transforms (e.g. when implementing MultivariateNormal as an AffineOperatorTransform of a univariate Normal), we need to sum out the rightmost dimensions of base_dist.log_prob().
  • When the transforms themselves have differing event_dims, we need to sum each term down to the largest event_dim.
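
A minimal sketch of the resulting TransformedDistribution.log_prob() logic (illustrative, not the verbatim merged source; the _sum_rightmost helper is a stand-in):

```python
import torch

def _sum_rightmost(value, dim):
    # Sum out the `dim` rightmost dimensions of `value`.
    for _ in range(dim):
        value = value.sum(-1)
    return value

def log_prob(self, value):
    # Score `value` under base_dist, correcting by log|det J| of each
    # transform; every term is first reduced to the composite event_dim.
    event_dim = len(self.event_shape)
    log_prob = 0.0
    y = value
    for transform in reversed(self.transforms):
        x = transform.inv(y)
        log_prob = log_prob - _sum_rightmost(
            transform.log_abs_det_jacobian(x, y),
            event_dim - transform.event_dim)
        y = x
    log_prob = log_prob + _sum_rightmost(
        self.base_dist.log_prob(y),
        event_dim - len(self.base_dist.event_shape))
    return log_prob
```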

This PR also includes fixes to ComposeTransform.event_dim and TransformedDistribution.event_shape to support multivariate transforms.
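
For the composite event_dim, the natural fix is to take the maximum over the parts; roughly (a sketch, assuming the component transforms live in self.parts):

```python
@property
def event_dim(self):
    # The composite must treat as one event at least as many rightmost
    # dims as its most multivariate part.
    return max(p.event_dim for p in self.parts) if self.parts else 0
```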

This PR was the result of issues that came up in @rachtsingh's probtorch#113 and in @fritzo's refactoring of InverseAutoregressiveFlow in Pyro as we build on top of torch.distributions.transforms.

Tested

  • More tests for TransformedDistribution
  • New tests for Transform shapes (see the example below)
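
For example, the sort of shape behavior those tests pin down (a hedged illustration; the actual test cases differ):

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(torch.zeros(3, 2), torch.ones(3, 2))
d = TransformedDistribution(base, [ExpTransform()])  # LogNormal-like
x = d.sample()
# ExpTransform has event_dim == 0, so log_prob keeps the full batch shape.
assert d.log_prob(x).shape == torch.Size([3, 2])
```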


fritzo commented Jan 30, 2018

cc @apaszke

term = part.log_abs_det_jacobian(x, y)
# Sum out the extra rightmost dims so every part's term
# matches the composite self.event_dim before accumulating.
for _ in range(self.event_dim - part.event_dim):
    term = term.sum(-1)
result = result + term
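
Reducing each part's term before accumulation is what lets ComposeTransform mix parts with differing event_dims: every Jacobian contribution is brought down to the shared self.event_dim, so the running sum broadcasts cleanly.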


the codomain. Transforms that are not bijective should at least
maintain the weaker pseudoinverse properties
``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
event_dim (int): Number of dimensions in the transform ``event_shape``.
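
A quick illustration of the weaker pseudoinverse property, assuming AbsTransform (x ↦ |x|, not bijective) is available:

```python
import torch
from torch.distributions.transforms import AbsTransform

t = AbsTransform()
x = torch.tensor([-2.0, 3.0])
# abs is not invertible, but abs(abs(x)) == abs(x) still holds.
assert torch.equal(t(t.inv(t(x))), t(x))
```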



apaszke commented Jan 31, 2018

LGTM, can you please resolve the conflicts?

@apaszke apaszke closed this Jan 31, 2018
@apaszke apaszke reopened this Jan 31, 2018
@apaszke apaszke merged commit ca5071d into pytorch:master Jan 31, 2018
ssnl added a commit to ssnl/pytorch that referenced this pull request Jan 31, 2018
soumith pushed a commit that referenced this pull request Jan 31, 2018
* Revert "Clarify grad_input_mask documentation in derivatives.yaml (#4963)"

This reverts commit 6f3266b.

* Revert "fix triu and tril for zero-strided inputs on gpu (#4962)"

This reverts commit 6c197c2.

* Revert "Add mutex for CPU RNG and move TH to C++ (#4041)"

This reverts commit 96239dd.

* Revert "Support multivariate TransformedDistributions (#4937)"

This reverts commit ca5071d.

* Revert "Only check that arguments are Variables in VariableType (#4943)"

This reverts commit d444379.

* Revert "torch.set_num_threads sets MKL option too (#4949)"

This reverts commit 2aaeec0.

Labels: module: distributions (Related to torch.distributions), open source
