Support multivariate TransformedDistributions #4937
Conversation
cc @apaszke
torch/distributions/transforms.py (Outdated)

```python
result += part.log_abs_det_jacobian(x, y)
term = part.log_abs_det_jacobian(x, y)
for _ in range(self.event_dim - part.event_dim):
    term = term.sum(-1)
```
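The loop in the snippet above sums out one trailing dimension per iteration. A minimal sketch of that reduction, using NumPy in place of torch tensors to illustrate the shape arithmetic (the helper name `sum_rightmost` is mine, not from the PR):

```python
import numpy as np

def sum_rightmost(term, n):
    # Mirrors the loop above: each `term = term.sum(-1)` removes
    # one trailing (rightmost) dimension.
    for _ in range(n):
        term = term.sum(-1)
    return term

# A per-element log-abs-det term with batch shape (2,) and event shape (3,).
term = np.ones((2, 3))

# Summing out 1 rightmost dim leaves one value per batch element.
reduced = sum_rightmost(term, 1)
assert reduced.shape == (2,)
assert np.allclose(reduced, [3.0, 3.0])
```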
torch/distributions/transforms.py (Outdated)

```python
        the codomain. Transforms that are not bijective should at least
        maintain the weaker pseudoinverse properties
        ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
    event_dim (int): Number of dimensions in the transform ``event_shape``.
```
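The weaker pseudoinverse properties in the docstring can be checked on a toy non-bijective map. This sketch (not PyTorch code; `t` and `t_inv` are illustrative stand-ins for `Transform.__call__` and `Transform.inv`) uses squaring on the reals:

```python
import math

# Toy non-bijective transform t(x) = x**2 with right-inverse
# t.inv(y) = sqrt(y). Since t(-2) == t(2), full inversion fails,
# but the weaker pseudoinverse properties still hold.
t = lambda x: x * x
t_inv = lambda y: math.sqrt(y)

x, y = -2.0, 4.0
assert t_inv(t(x)) != x                # full inversion fails: sqrt(4) is 2, not -2
assert t(t_inv(t(x))) == t(x)          # t(t.inv(t(x))) == t(x)
assert t_inv(t(t_inv(y))) == t_inv(y)  # t.inv(t(t.inv(y))) == t.inv(y)
```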
LGTM, can you please resolve the conflicts?
This reverts commit ca5071d.

* Revert "Clarify grad_input_mask documentation in derivatives.yaml (#4963)" This reverts commit 6f3266b.
* Revert "fix triu and tril for zero-strided inputs on gpu (#4962)" This reverts commit 6c197c2.
* Revert "Add mutex for CPU RNG and move TH to C++ (#4041)" This reverts commit 96239dd.
* Revert "Support multivariate TransformedDistributions (#4937)" This reverts commit ca5071d.
* Revert "Only check that arguments are Variables in VariableType (#4943)" This reverts commit d444379.
* Revert "torch.set_num_threads sets MKL option too (#4949)" This reverts commit 2aaeec0.
Reviewed by @rachtsingh and @alicanb at probtorch#116
This adds an `.event_dim` attribute to all `Transform`s and correctly handles event shape in `TransformedDistribution.log_prob()` and `ComposeTransform.log_abs_det_jacobian()`. Cases we need to handle are:

* When `TransformedDistribution.base_dist` has a larger `event_dim` than its transforms, we need to sum out the rightmost dimensions in the `transform.log_abs_det_jacobian()`s, otherwise there will be a shape error.
* When `TransformedDistribution.base_dist` has a smaller `event_dim` than its transforms (e.g. when implementing `MultivariateNormal` as an `AffineOperatorTransform` of univariate `Normal`), we need to sum out the rightmost dimensions of `base_dist.log_prob()`.
* When the parts of a `ComposeTransform` differ in `event_dim`, we need to sum out all but the largest dim.

This PR also includes fixes to `ComposeTransform.event_dim` and `TransformedDistribution.event_shape` to support multivariate transforms.

This PR was the result of issues that came up in @rachtsingh's probtorch#113 and in @fritzo's refactoring of `InverseAutoregressiveFlow` in Pyro, as we build on top of `torch.distributions.transforms`.

Tested:

* `TransformedDistribution` shapes
* `Transform` shapes
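A sketch of the `log_prob` shape handling the PR describes, using NumPy rather than the actual PyTorch implementation (the shapes and the `sum_rightmost` helper are illustrative assumptions): a jacobian term from a transform whose `event_dim` is smaller than the distribution's must have its extra rightmost dimensions summed out before being combined with `base_dist.log_prob`.

```python
import numpy as np

def sum_rightmost(x, n):
    # Sum out the n rightmost dimensions of x.
    for _ in range(n):
        x = x.sum(-1)
    return x

# Hypothetical shapes: batch (2,), event (3,). A univariate transform
# (event_dim=0) viewed under a multivariate distribution (event_dim=1)
# yields a per-element jacobian of shape (2, 3) that must be reduced
# to the batch shape (2,) before combining with base_dist.log_prob.
event_dim, part_event_dim = 1, 0
log_abs_det = np.full((2, 3), 0.5)
base_log_prob = np.zeros((2,))  # already reduced over the event dims

log_prob = base_log_prob - sum_rightmost(log_abs_det, event_dim - part_event_dim)
assert log_prob.shape == (2,)
assert np.allclose(log_prob, -1.5)
```

Without the `sum_rightmost` reduction, subtracting a `(2, 3)` jacobian from a `(2,)` log-probability would broadcast incorrectly or raise a shape error, which is exactly the failure mode the PR fixes.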