Description
Originally, the primTorch project targeted stride consistency between reference implementations and PyTorch's eager mode. This has proved to be an issue for several reasons:
1) PyTorch eager's striding is inconsistent.
See #77731 and #77553 for some examples. @ngimel has fixed several of these issues on CUDA, too. See #77610 and #77585.
These issues suggest that stride consistency is not currently important enough to strive for in PyTorch eager, and our users do not seem particularly affected by inconsistent striding when it does occur.
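To make the problem concrete, here is a small probe (a sketch, not taken from the linked issues; the exact strides printed depend on your PyTorch version and device) showing the kind of choice eager mode has to make when inputs disagree about layout:

```python
import torch

# Two inputs with the same shape but different memory layouts:
# `a` is contiguous, `b` is a transposed (column-major) view.
a = torch.randn(4, 4)
b = torch.randn(4, 4).t()

# Eager mode must pick some output layout for the result.
out = torch.add(a, b)
print("a strides:  ", a.stride())      # e.g. (4, 1)
print("b strides:  ", b.stride())      # e.g. (1, 4)
print("out strides:", out.stride())    # whichever layout eager picks

# A mathematically equivalent reference implementation may make a
# different choice, which is the kind of divergence tracked above.
```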
2) Our elementwise striding behavior is not commutative or associative.
Non-commutative or non-associative behaviors (our type promotion is also non-associative!) are often a pain, because different, equally valid reference implementations of an operator then have to work around the non-commutativity or non-associativity. For example, a valid decomposition of clamp (when both min_value and max_value are specified) is:
minimum(maximum(a, min_value), max_value)
However, the application of two elementwise binary operators (maximum and minimum) is not, in general, equivalent to the application of a single elementwise ternary operator (clamp)! We work around the type promotion discrepancy by wrapping every ref in the appropriate type promotion wrapper, but we don't think it's reasonable to develop an analogous striding wrapper. Thus, if we enforce strict stride consistency, we will be limited in how we naturally write references, and, as a design limitation, every elementwise ternary operator will require a corresponding prim.
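As an illustration (a sketch, not taken from the linked issues; the actual strides depend on the PyTorch version and device), one can compare the output strides of the fused ternary op against the binary-op decomposition on a non-contiguous input:

```python
import torch

# A non-contiguous input (a transposed view) plus broadcast scalars.
a = torch.randn(4, 4).t()
min_value, max_value = -0.5, 0.5

# Fused ternary elementwise op.
fused = torch.clamp(a, min_value, max_value)

# Decomposition into two binary elementwise ops.
decomposed = torch.minimum(torch.maximum(a, torch.tensor(min_value)),
                           torch.tensor(max_value))

# The values agree...
assert torch.equal(fused, decomposed)

# ...but the strides are free to differ, because two applications of a
# binary op need not propagate layout the same way a single ternary op does.
print("fused strides:     ", fused.stride())
print("decomposed strides:", decomposed.stride())
```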
3) Operations that sometimes return a view and sometimes return a new tensor are difficult to model, and we tell users not to rely on this behavior in all but the simplest cases.
The reshape documentation warns the user: "you should not depend on [its] copying vs. viewing behavior." contiguous, on the other hand, carries no such warning.
When tracing, capturing view semantics today depends on careful stride analysis, so representing views correctly requires very high-fidelity stride consistency.
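For example (a sketch; exactly which cases avoid a copy is an implementation detail that may change), whether reshape aliases its input can be observed by comparing storage pointers:

```python
import torch

base = torch.arange(6).reshape(2, 3)

# Reshaping a contiguous tensor can be satisfied with a view...
v = base.reshape(3, 2)
print("shares storage:", v.data_ptr() == base.data_ptr())   # typically True

# ...but reshaping a transposed (non-contiguous) tensor forces a copy.
t = base.t()
c = t.reshape(6)
print("shares storage:", c.data_ptr() == t.data_ptr())      # typically False
```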
Proposal
Given that users don't seem to demand absolute stride consistency today, that we want to let users write reasonable decompositions without worrying too much about strides, and that we'd prefer to model views and striding behavior more simply when tracing, we propose changing our operator semantics to be stride agnostic. @zou3519 has long advocated for reshape always returning a copy from the user's perspective (his full proposal is a more nuanced copy-on-write scheme that would preserve reshape's performance), and this proposal would require that work, too, to ensure consistency between PyTorch eager and traces. contiguous would need the same behavior as well.
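To make the copy-on-write idea concrete, here is a toy, purely illustrative sketch (not @zou3519's actual proposal and not an existing PyTorch API) of how a reshape-like operation could present copy semantics to the user while deferring the physical copy until the first write:

```python
import torch

class COWTensor:
    """Toy copy-on-write wrapper: reads alias the source, and the first
    in-place write materializes a private copy. Illustrative only."""

    def __init__(self, source: torch.Tensor, shape):
        self._source = source
        self._shape = shape
        self._materialized = None  # private copy, created lazily

    def _read_tensor(self) -> torch.Tensor:
        if self._materialized is not None:
            return self._materialized
        # Reads can be served by a (possibly aliasing) reshape of the source.
        return self._source.reshape(self._shape)

    def _write_tensor(self) -> torch.Tensor:
        if self._materialized is None:
            # First write: detach from the source so the user always
            # observes copy semantics.
            self._materialized = self._source.reshape(self._shape).clone()
        return self._materialized

    def sum(self):
        return self._read_tensor().sum()

    def fill_(self, value):
        self._write_tensor().fill_(value)
        return self


a = torch.zeros(2, 3)
b = COWTensor(a, (6,))   # "reshape" with copy semantics
b.fill_(1.0)             # triggers the private copy on first write
print(a)                 # the original is unchanged, as a true copy would guarantee
```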
It may seem like there's a middle-ground approach where we model contiguity, for instance, and so don't have to modify the contiguous operation, but we don't think there is. Any property we model would have to be "closed" -- that is, whether an output has the property must be determined by whether the inputs have it -- and our operators are not closed w.r.t. contiguity. We could possibly model additional properties, like "denseness" and whether something is "permuted," but these schemes seem complicated and offer little benefit.
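As a rough illustration of why contiguity is not closed (a probe rather than a guarantee; the exact output layouts depend on the PyTorch version), two inputs that are both non-contiguous can produce outputs whose contiguity differs:

```python
import torch

base = torch.randn(4, 4)

# Two non-contiguous inputs with different layouts.
transposed = base.t()        # dense but permuted
strided    = base[:, ::2]    # non-dense slice

for name, x in [("transposed", transposed), ("strided", strided)]:
    out = torch.relu(x)
    print(f"{name}: input contiguous={x.is_contiguous()}, "
          f"output contiguous={out.is_contiguous()}, "
          f"output strides={out.stride()}")

# Knowing only "the input is non-contiguous" does not tell us whether the
# output will be contiguous, so contiguity alone cannot be propagated as a
# closed property through our operators.
```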