KEMBAR78
[dynamo] Remove `mutable_local.source` and index on `VariableTracker` rather than `MutableLocalBase` by StrongerXi · Pull Request #137905 · pytorch/pytorch · GitHub
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions torch/_dynamo/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
root: Optional[torch.nn.Module] = None,
graph_output_var: Optional[str] = None,
tempvars=None,
overridden_sources=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add type annotations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

) -> None:
self.root = root
self.top_of_stack: Optional[VariableTracker] = None
Expand All @@ -65,6 +66,10 @@ def __init__(
self.new_var = self.tx.output.new_var
self.mutable_side_effects_from_source = False
self.value_from_source: bool = True
# This serves as a way for codegen to use a different source; we need
# this because sometimes we can't easily modify the original source
# without affecting other components, e.g., guards.
self.overridden_sources: Dict[Source, Source] = overridden_sources or {}

def restore_stack(self, stack_values, *, value_from_source=True):
prior = self.mutable_side_effects_from_source
Expand Down Expand Up @@ -116,7 +121,9 @@ def add_push_null(self, gen_fn, call_function_ex=False):
def __call__(self, value, allow_cache=True):
"""Generate code such that top-of-stack (TOS) is set to value"""
if isinstance(value, Source):
self.call_reconstruct(value)
# If the source needs to be overridden, use the new one.
source = self.overridden_sources.get(value, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need some comments explaining this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

self.call_reconstruct(source)
self.clear_tos()
return

Expand All @@ -130,27 +137,25 @@ def __call__(self, value, allow_cache=True):

if self.mutable_side_effects_from_source:
# this is needed to get aliasing relationships right
# value.mutable_local.source will get mutated to hold `value`
# value.source will get mutated to hold `value`
# mutable_side_effects_from_source=False is used to codegen the mutation
# mutable_side_effects_from_source=True is used to codegen a reference
from .side_effects import MutableSideEffects

if isinstance(value.mutable_local, MutableSideEffects):
self(value.mutable_local.source)
self(value.source)
return

if allow_cache:
if value.mutable_local and value.mutable_local in self.tempvars:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one feels like a separate change from this PR. Though perhaps it is ok because VariableTrackers are now mutable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is me trying to remove indexing on MutableLocal.

output.append(self.create_load(self.tempvars[value.mutable_local]))
self.top_of_stack = value
return
if self.tempvars.get(value) is not None:
output.append(self.create_load(self.tempvars[value]))
self.top_of_stack = value
return

if value.source is not None and allow_cache and self.value_from_source:
self.call_reconstruct(value.source)
# If the source needs to be overridden, use the new one.
source = self.overridden_sources.get(value.source, value.source)
self.call_reconstruct(source)
elif value.is_python_constant() and is_safe_constant(
value.as_python_constant()
):
Expand Down Expand Up @@ -254,8 +259,6 @@ def load_graph_output(self, index):
def add_cache(self, value):
var = self.new_var()
self.tempvars[value] = var
if value.mutable_local:
self.tempvars[value.mutable_local] = var
self._output.append(self.create_store(var))

def foreach(self, items):
Expand Down
57 changes: 44 additions & 13 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,10 @@ def handle_aliases_for_stolen_lists(self, tx):
maybe_gm = self.local_scope.get("self")
stolen_list_names = get_locals_to_steal(maybe_gm)
if not stolen_list_names:
return []
return [], {}

alias_insts = []
needs_alias: Dict[
str, List[Union[VariableTracker, AttributeMutationExisting]]
] = {}
needs_alias: Dict[str, List[VariableTracker]] = {}

queue = [
*tx.stack,
Expand All @@ -926,7 +924,10 @@ def handle_aliases_for_stolen_lists(self, tx):
continue

if not (
isinstance(x, (VariableTracker, AttributeMutationExisting))
(
x not in self.side_effects.store_attr_mutations
or isinstance(x.mutable_local, AttributeMutationExisting)
)
and isinstance(x.source, GetItemSource)
and isinstance(x.source.base, LocalSource)
and x.source.base.local_name in stolen_list_names
Expand All @@ -939,6 +940,7 @@ def handle_aliases_for_stolen_lists(self, tx):
needs_alias[stolen_name].append(x)

visited = {}
overridden_sources: Dict[Source, Source] = {}
for arg in self.graphargs:
if not (
isinstance(arg._example, list)
Expand All @@ -951,6 +953,12 @@ def handle_aliases_for_stolen_lists(self, tx):
list_name = arg.source.local_name
assert list_name in self.code_options["co_varnames"]
for x in needs_alias[list_name]:
# Skip if already handled.
if x.source in overridden_sources:
continue

# A small codegen optimization because we might have different
# VariableTrackers that share the same source.
list_idx = x.source.index
if list_idx not in visited:
alias_name = self.new_var(
Expand All @@ -969,9 +977,14 @@ def handle_aliases_for_stolen_lists(self, tx):
)

# operate on alias, handled by suffix codegen
x.source = LocalSource(visited[list_idx])
old_source = x.source
overridden_sources[old_source] = LocalSource(visited[list_idx])

return alias_insts
# NOTE: we need `overridden_sources` because (1) we want to codegen for
# these list items to use the new local source, but (2) we want to avoid
# updating `source` in place because that might break invariants in
# other parts of Dynamo like guards.
return alias_insts, overridden_sources

def compile_subgraph(
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
Expand Down Expand Up @@ -1013,7 +1026,8 @@ def compile_subgraph(
self.pregraph_bytecode and self.export
), "export does not support pregraph_bytecode"
prefix_insts.extend(self.pregraph_bytecode)
prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx))
alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(tx)
prefix_insts.extend(alias_insts)

def append_prefix_insts():
self.add_output_instructions(prefix_insts)
Expand Down Expand Up @@ -1081,7 +1095,7 @@ def append_prefix_insts():
self.random_values_var = self.new_var("random_values")
rand_fn = disable(_get_gen_rand_values_fn(self.random_calls))
rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
codegen = PyCodegen(tx, root)
codegen = PyCodegen(tx, root, overridden_sources=overridden_sources)
random_calls_instructions.extend(
codegen.load_function_name(rand_fn_name, True)
)
Expand Down Expand Up @@ -1119,11 +1133,18 @@ def append_prefix_insts():
)
# restore all the live local vars
self.add_output_instructions(
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
[
PyCodegen(tx, overridden_sources=overridden_sources).create_store(
var
)
for var in reversed(restore_vars)
]
)
else:
graph_output_var = self.new_var("graph_out")
pass1 = PyCodegen(tx, root, graph_output_var)
pass1 = PyCodegen(
tx, root, graph_output_var, overridden_sources=overridden_sources
)
self.codegen_suffix(tx, stack_values, pass1)

# one more time now that we have established tempvars
Expand All @@ -1132,6 +1153,7 @@ def append_prefix_insts():
root,
graph_output_var,
tempvars={val: None for val, count in pass1.uses.items() if count > 1},
overridden_sources=overridden_sources,
)
self.codegen_suffix(tx, stack_values, pass2)

Expand All @@ -1156,12 +1178,21 @@ def append_prefix_insts():

# restore all the live local vars
self.add_output_instructions(
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
[
PyCodegen(tx, overridden_sources=overridden_sources).create_store(
var
)
for var in reversed(restore_vars)
]
)

if stored_graph_output_var:
self.add_output_instructions(
[PyCodegen(tx).create_delete(graph_output_var)]
[
PyCodegen(
tx, overridden_sources=overridden_sources
).create_delete(graph_output_var)
]
)

def codegen_suffix(self, tx, stack_values, cg):
Expand Down
Loading
Loading