Description
The docs claim that mesh_dim=None should result in returning a list of all PGs. (This is actually a convenient behavior and what I was hoping to use in this case.)
Update: I re-read the docs and realized they accurately describe what happens, which is that there is a special case for a 1-D mesh. I'd still like to discuss this, as it's convenient to have a method that returns a list of all the groups, even if there is only one group.
get_group() is returning a single PG rather than a list. In this case I was running torchtrain with only DP enabled, so the mesh was 1-D. That may be a special case, but I think it should still return a list with just the DP group in it. That way I can write something like:
for group in world_mesh.get_group():
    ...
Note: the API name is also a bit off. It's weird to have an API named 'get_group' that sometimes gets 'groups'. I'm not sure whether it's worth changing the behavior (adding a separate API for groups and making get_group require mesh_dim) or just living with the awkward naming.
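To make the "separate API for groups" idea concrete, here is a minimal sketch of a helper that always returns a list, including for a 1-D mesh. The name get_all_groups is hypothetical and not part of DeviceMesh. The sketch assumes only that the mesh exposes ndim and that get_group(mesh_dim=i) returns the single group for dimension i. FakeMesh is a stand-in that mimics the behavior described above (a bare group for a 1-D mesh when mesh_dim is None) so the example runs without torch.distributed.

```python
from typing import Any, List


def get_all_groups(mesh: Any) -> List[Any]:
    """Hypothetical helper: always return a list with one group per mesh dim.

    Assumes `mesh.ndim` exists and `mesh.get_group(mesh_dim=i)` returns the
    single process group for dimension i.
    """
    return [mesh.get_group(mesh_dim=dim) for dim in range(mesh.ndim)]


class FakeMesh:
    """Stand-in for a device mesh; groups are represented by plain strings."""

    def __init__(self, dim_names):
        self._dim_names = list(dim_names)

    @property
    def ndim(self):
        return len(self._dim_names)

    def get_group(self, mesh_dim=None):
        if mesh_dim is None:
            if self.ndim == 1:
                # The 1-D special case discussed above: a single group, not a list.
                return f"pg:{self._dim_names[0]}"
            return [f"pg:{name}" for name in self._dim_names]
        return f"pg:{self._dim_names[mesh_dim]}"


one_d = FakeMesh(["dp"])
two_d = FakeMesh(["dp", "tp"])

# The same loop works regardless of how many dimensions the mesh has.
for group in get_all_groups(one_d):
    print(group)
```

With a helper like this, get_group could keep returning exactly one group, and callers that want to iterate would not need to special-case the 1-D mesh.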
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k @rohan-varma