KEMBAR78
[pipelining] clean up stage functions by H-Huang · Pull Request #140418 · pytorch/pytorch · GitHub
Skip to content

Conversation

@H-Huang
Copy link
Member

@H-Huang H-Huang commented Nov 12, 2024

Stack from ghstack (oldest at bottom):

Clean up methods related to stage input/output shape verification which are no longer needed

cc @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140418

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c90018d with merge base 0a0915f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 12, 2024
H-Huang added a commit that referenced this pull request Nov 12, 2024
ghstack-source-id: ea0af1d
Pull Request resolved: #140418
@H-Huang H-Huang added the release notes: distributed (pipeline) release notes category label Nov 12, 2024
Copy link
Contributor

@wconstab wconstab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iiuc the 4 helpers at the top are deleted because they were only used by validate-shape helper at the bottom?

for validate shape, i guess we are not using it anywhere?

if the user uses shape-inference which is encouraged, then validate shape seems pointless. And if they do not use shape-inference, i think shape validation is nice to have but i kinda prefer deleting all this code if possible. Do we have additional shape validation helpers that don't do communication? IIRC i added something that recorded the expected shape at init time and errored at runtime if it mismatches. If we have that then i think this is all safe to delete.

@H-Huang
Copy link
Member Author

H-Huang commented Nov 12, 2024

iiuc the 4 helpers at the top are deleted because they were only used by validate-shape helper at the bottom? for validate shape, i guess we are not using it anywhere?

Yes to these both

i added something that recorded the expected shape at init time and errored at runtime if it mismatches.

I just checked and yeah I see this as part of

def _validate_fwd_input(self, args, kwargs):
"""Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
if self.is_first:
# TODO why is there a separate recv_info for each pipeline chunk?
# kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
# check all chunks against args_recv_info[0]
expected_args = self.args_recv_info[0]
else:
# We don't check inputs for non-0 stages assuming they don't accept
# user inputs in canonical pipeline scenarios
return
if len(kwargs):
# TODO- need a mapping of kwarg to position in self.args_recv_info
# Without it, we are not 100% sure how to match the args and
# expected_args.
return
# TODO- need a mapping of kwarg to position in self.args_recv_info
# maybe it's impossible to tell whether the len mismatches because
# (a) the user passed an extra arg or missed an arg
# (b) the user did not pass a kwarg, which has a default value baked into expected_args
expected_tensors_meta = [
e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
for e in expected_args
]
validate_tensors_metadata(
f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
)
def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
"""Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
Most likely, this could be cause either by incorrect user specification of output shapes, or becuase
shape inference was done on the original model but then at runtime the model is wrapped with something like
mixed precision which changes output dtype.
"""
expected_tensors_meta = self.get_outputs_meta()
validate_tensors_metadata(
f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
)

@H-Huang
Copy link
Member Author

H-Huang commented Nov 12, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 12, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Clean up methods related to stage input/output shape verification which are no longer needed

Pull Request resolved: pytorch#140418
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#140019
@github-actions github-actions bot deleted the gh/H-Huang/154/head branch December 14, 2024 02:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (pipeline) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants