
RFC: [primTorch] Stride-agnostic Operator Semantics #78050

@mruberry

Description

Originally the primTorch project was targeting stride consistency between reference implementations and PyTorch's eager mode. This has proved to be an issue for several reasons:

1) PyTorch eager's striding is inconsistent.

See #77731 and #77553 for some examples. @ngimel has fixed several of these issues on CUDA, too. See #77610 and #77585.
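To make the inconsistency concrete, here's a small sketch (not taken from the linked issues): the output layout of an elementwise op is chosen by TensorIterator heuristics from the inputs' layouts, and those choices can and do differ across devices, code paths, and PyTorch versions.

```python
import torch

# Mixed-layout inputs to an elementwise op: the output layout is chosen by
# TensorIterator heuristics, not by any documented stride contract.
a = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
b = torch.randn(2, 3, 4, 5)  # default contiguous layout

out = a + b
print(a.stride())    # channels-last strides
print(b.stride())    # contiguous strides
print(out.stride())  # whichever layout the heuristics picked for the output
```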

These issues suggest that stride consistency has not been important enough to enforce rigorously in PyTorch eager today, and our users do not seem particularly affected by inconsistent striding when it does occur.

2) Our elementwise striding behavior is not commutative or associative.

Non-commutative or non-associative behavior (our type promotion is also non-associative!) is often a pain, because it means that otherwise-valid reference implementations for an operator must work around the non-commutativity or non-associativity. For example, a valid decomposition of clamp (when both min_value and max_value are specified) is:

`minimum(maximum(a, min_value), max_value)`

However, the application of two elementwise binary operators (maximum and minimum) is not, in general, equivalent to the application of a single elementwise ternary operator (clamp)! We work around the type promotion discrepancy by wrapping every ref in the appropriate type promotion wrapper, but we don't think it's reasonable to develop an analogous striding wrapper. Thus, if we enforce strict stride consistency, we will be limited in how we naturally write references, and any elementwise ternary operator would, as a design limitation, require a corresponding prim.
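Here's a minimal sketch of that discrepancy, assuming tensor-valued bounds; whether the strides actually differ depends on TensorIterator's layout heuristics for the particular inputs:

```python
import torch

a = torch.randn(4, 6).t()   # non-contiguous (transposed) input
lo = torch.zeros(6, 4)      # contiguous bounds
hi = torch.ones(6, 4)

fused = torch.clamp(a, min=lo, max=hi)
decomposed = torch.minimum(torch.maximum(a, lo), hi)

torch.testing.assert_close(fused, decomposed)  # values agree
print(fused.stride(), decomposed.stride())     # strides need not agree
```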

3) Operations that sometimes return a view and sometimes return a new tensor are difficult to model, and we tell users not to rely on this behavior for all but the simplest cases.

The reshape documentation directs the user: "you should not depend on [its] copying vs. viewing behavior." contiguous, on the other hand, carries no such warning even though it also sometimes returns the input unchanged and sometimes returns a copy.
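For illustration, a quick example of reshape's sometimes-view, sometimes-copy behavior:

```python
import torch

base = torch.arange(6)
v = base.reshape(2, 3)   # representable as a view, so it aliases base
base[0] = 100
print(v[0, 0])           # tensor(100): the reshape returned a view

nc = torch.arange(6).reshape(2, 3).t()
c = nc.reshape(6)        # not representable as a view, so it copies
nc[0, 0] = -1
print(c[0])              # tensor(0): the reshape returned a copy
```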

When tracing, capturing view semantics today depends on careful stride analysis, so if we want to represent views correctly we need very high-fidelity stride consistency.

Proposal

Given that users don't seem to be demanding absolute stride consistency today, that we want to let users write reasonable decompositions without worrying too much about strides, and that we'd prefer to model views and striding behavior more simply when tracing, we propose changing our operator semantics to be stride agnostic. @zou3519 has long advocated for reshape always returning a copy from the user's perspective (his full proposal is a more nuanced copy-on-write idea that would preserve reshape's performance), and this proposal would require that work be done, too, to ensure consistency between PyTorch eager and traces. It would require the same behavior from contiguous as well.
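As a sketch of what "stride agnostic" would mean in practice (a hypothetical reference, not an existing primTorch decomposition), a reference would be judged on values, shape, and dtype, never on strides:

```python
import torch

def clamp_ref(a, min_value, max_value):
    # Natural decomposition; under stride-agnostic semantics its output
    # strides are allowed to differ from eager torch.clamp's.
    return torch.minimum(torch.maximum(a, min_value), max_value)

a = torch.randn(3, 4).t()  # non-contiguous input
expected = torch.clamp(a, min=-0.5, max=0.5)
actual = clamp_ref(a, torch.tensor(-0.5), torch.tensor(0.5))

torch.testing.assert_close(actual, expected)  # values must match; strides are free
```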

It may seem like there's a middle-ground approach where we model contiguity, for instance, and so don't have to modify the contiguous operation, but we don't think there is. Any property we model has to be "closed" -- that is, whether an output has the property must be determined by whether the inputs have it -- and our operators are not closed w.r.t. contiguity. We could possibly model additional properties, like "denseness" and whether something is "permuted," but these schemes seem complicated and offer little benefit.
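A small example of why contiguity is not closed under our operators: the same kind of indexing on the same contiguous input yields outputs whose contiguity differs, so the property cannot be propagated from the inputs' contiguity alone.

```python
import torch

x = torch.randn(4, 4)            # contiguous input
print(x[:2, :].is_contiguous())  # True:  slicing leading rows preserves contiguity
print(x[:, :2].is_contiguous())  # False: slicing columns does not
```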

cc @ezyang @mruberry @ngimel
