-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Fix default timeouts for python entrypoints (e.g. init_process_group) #112893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
faa01ae
9048701
b1db594
d1703d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,23 @@ | ||
| from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT | ||
| from datetime import timedelta | ||
| from typing import Optional | ||
|
|
||
| __all__ = ['default_pg_timeout', 'default_pg_nccl_timeout'] | ||
|
|
||
| # Default process group wide timeout, if applicable. | ||
| # This only applies to the gloo and nccl backends | ||
| # (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1). | ||
| # This only applies to the non-nccl backends | ||
| # To make an attempt at backwards compatibility with THD, we use an | ||
| # extraordinarily high default timeout, given that THD did not have timeouts. | ||
| default_pg_timeout = _DEFAULT_PG_TIMEOUT | ||
| default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT | ||
| # Separate timeout for PGNCCL mainly because it's always been that way in the C++ layer, but until recently | ||
| # there was one default that applied across all backends in the python layer. | ||
| # Later, we could consider merging them back together at the c++ layer if we can align on a same value. | ||
| # (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1). | ||
|
|
||
| try: | ||
| from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT | ||
| default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT | ||
| except ImportError: | ||
| # if C++ NCCL support is not compiled, we don't have access to the default nccl value. | ||
| # if anyone is actually trying to use nccl in this state, it should error. | ||
| default_pg_nccl_timeout = None |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,7 @@ | |
| get_debug_level, | ||
| Work | ||
| ) | ||
| from .constants import default_pg_timeout | ||
| from .constants import default_pg_timeout, default_pg_nccl_timeout | ||
| from .c10d_logger import _exception_logger, _time_logger | ||
| from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 | ||
| DistStoreError = torch._C._DistStoreError | ||
|
|
@@ -575,6 +575,20 @@ class GroupMember(metaclass=_WorldMeta): | |
| NON_GROUP_MEMBER = -100 | ||
|
|
||
|
|
||
| def _get_default_timeout(backend: Backend) -> timedelta: | ||
| # see note on nccl vs other backend timeout (constants.py) | ||
| if backend == Backend.NCCL: | ||
| assert isinstance(default_pg_nccl_timeout, timedelta), "no NCCL default timeout, is NCCL support compiled?" | ||
| return default_pg_nccl_timeout | ||
| else: | ||
| return default_pg_timeout | ||
|
|
||
| def _check_valid_timeout(timeout: Any) -> None: | ||
| if not isinstance(timeout, timedelta): | ||
| raise TypeError( | ||
| f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" | ||
| ) | ||
|
|
||
| # Default process group state | ||
| _default_pg_init_method = None | ||
|
|
||
|
|
@@ -1015,7 +1029,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str: | |
| def init_process_group( | ||
| backend: Union[str, Backend] = None, | ||
| init_method: Optional[str] = None, | ||
| timeout: timedelta = default_pg_timeout, | ||
| timeout: Optional[timedelta] = None, | ||
| world_size: int = -1, | ||
| rank: int = -1, | ||
| store: Optional[Store] = None, | ||
|
|
@@ -1060,26 +1074,14 @@ def init_process_group( | |
| to exchange connection/address information. | ||
| Mutually exclusive with ``init_method``. | ||
| timeout (timedelta, optional): Timeout for operations executed against | ||
| the process group. Default value equals 30 minutes. | ||
| This is applicable for the ``gloo`` backend. For ``nccl``, this is | ||
| applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` | ||
| or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When | ||
| ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the | ||
| process will block and wait for collectives to complete before | ||
| throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, | ||
| this is the duration after which collectives will be aborted | ||
| asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` | ||
| will provide errors to the user which can be caught and handled, | ||
| but due to its blocking nature, it has a performance overhead. On | ||
| the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little | ||
| performance overhead, but crashes the process on errors. This is | ||
| done since CUDA execution is async and it is no longer safe to | ||
| continue executing user code since failed async NCCL operations | ||
| might result in subsequent CUDA operations running on corrupted | ||
| data. Only one of these two environment variables should be set. | ||
| For ``ucc``, blocking wait is supported similar to NCCL. However, | ||
| async error handling is done differently since with UCC we have | ||
| progress thread and not watch-dog thread. | ||
| the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. | ||
| When ``NCCL_ASYNC_ERROR_HANDLING`` is set, this is the duration after which collectives will be aborted | ||
| asynchronously and the process will crash. | ||
| This is done since CUDA execution is async and it is no longer safe to continue executing user code since | ||
| failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. | ||
| When ``NCCL_BLOCKING_WAIT`` is set, the process will block and wait for this timeout. | ||
|
|
||
| group_name (str, optional, deprecated): Group name. This argument is ignored | ||
| pg_options (ProcessGroupOptions, optional): process group options | ||
| specifying what additional options need to be passed in during | ||
|
|
@@ -1105,11 +1107,6 @@ def init_process_group( | |
| global _backend | ||
| global _default_pg_init_method | ||
|
|
||
| if not isinstance(timeout, timedelta): | ||
| raise TypeError( | ||
| "Expected timeout argument to be of type datetime.timedelta" | ||
| ) | ||
|
|
||
| if GroupMember.WORLD is not None: | ||
| raise ValueError("trying to initialize the default process group twice!") | ||
|
|
||
|
|
@@ -1128,6 +1125,11 @@ def init_process_group( | |
| else: | ||
| backend = Backend("undefined") | ||
|
|
||
| if timeout is None: | ||
| timeout = _get_default_timeout(backend) | ||
|
|
||
| _check_valid_timeout(timeout) | ||
|
|
||
| """ | ||
| Group name is not visible to users unless they access | ||
| internals of c10d. This means we can ignore the value | ||
|
|
@@ -1205,7 +1207,7 @@ def _new_process_group_helper( | |
| store, | ||
| group_name, | ||
| pg_options=None, | ||
| timeout=default_pg_timeout, | ||
| timeout=None, | ||
| pg_tag=None | ||
| ): | ||
| """ | ||
|
|
@@ -1225,10 +1227,8 @@ def _new_process_group_helper( | |
| "created, please use a different group name" | ||
| ) | ||
|
|
||
| if not isinstance(timeout, timedelta): | ||
| raise TypeError( | ||
| "Expected timeout argument to be of type datetime.timedelta" | ||
| ) | ||
| # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value | ||
| _check_valid_timeout(timeout) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. init_process_group checks the timeout and then calls _new_process_group_helper other 'new_group' APIs also call _new_process_group_helper directly. My first thought was to move all the checking inside the helper and take it out of all the outer functions. I still think this is a good idea. However, this is a larger change since init_process_group also has handling for |
||
|
|
||
| if pg_tag not in [None, ""]: | ||
| # creating with the same tag and rank set results in the same underlying PG | ||
|
|
@@ -3789,7 +3789,16 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals | |
| raise ValueError("monitored_barrier is only implemented for GLOO backend.") | ||
|
|
||
| if timeout is None: | ||
| timeout = default_pg_timeout | ||
| timeout = _get_default_timeout(get_backend(group)) | ||
| elif isinstance(timeout, float): | ||
| # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? | ||
| warnings.warn( | ||
| "Please specify timeout arg as a timedelta. " | ||
| f"Converting current value of {timeout} assuming it represents seconds", | ||
| ) | ||
| timeout = timedelta(seconds=timeout) | ||
|
|
||
| _check_valid_timeout(timeout) | ||
|
|
||
| group_to_use = _get_default_group() if group is None else group | ||
| return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks) | ||
|
|
@@ -3803,6 +3812,8 @@ def _create_process_group_wrapper( | |
| world_size: int, | ||
| timeout: timedelta = default_pg_timeout, | ||
| ): | ||
| # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... | ||
|
|
||
| # Create a separate prefix store for the helper process group. | ||
| prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" | ||
| store = PrefixStore(prefix, store) | ||
|
|
@@ -3832,7 +3843,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend: | |
|
|
||
|
|
||
| @_time_logger | ||
| def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, use_local_synchronization=False): | ||
| def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False): | ||
| """ | ||
| Creates a new distributed group. | ||
|
|
||
|
|
@@ -3855,24 +3866,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N | |
| Args: | ||
| ranks (list[int]): List of ranks of group members. If ``None``, will be | ||
| set to all ranks. Default is ``None``. | ||
| timeout (timedelta, optional): Timeout for operations executed against | ||
| the process group. Default value equals 30 minutes. | ||
| This is applicable for the ``gloo`` backend. For ``nccl``, this is | ||
| applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` | ||
| or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When | ||
| ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the | ||
| process will block and wait for collectives to complete before | ||
| throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, | ||
| this is the duration after which collectives will be aborted | ||
| asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` | ||
| will provide errors to the user which can be caught and handled, | ||
| but due to its blocking nature, it has a performance overhead. On | ||
| the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little | ||
| performance overhead, but crashes the process on errors. This is | ||
| done since CUDA execution is async and it is no longer safe to | ||
| continue executing user code since failed async NCCL operations | ||
| might result in subsequent CUDA operations running on corrupted | ||
| data. Only one of these two environment variables should be set. | ||
| timeout (timedelta, optional): see `init_process_group` for details and default value. | ||
| backend (str or Backend, optional): The backend to use. Depending on | ||
| build-time configurations, valid values are ``gloo`` and ``nccl``. | ||
| By default uses the same backend as the global group. This field | ||
|
|
@@ -3908,7 +3902,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N | |
|
|
||
| def _new_group_with_tag( | ||
| ranks=None, | ||
| timeout=default_pg_timeout, | ||
| timeout=None, | ||
| backend=None, | ||
| pg_options=None, | ||
| pg_tag=None, | ||
|
|
@@ -3927,11 +3921,19 @@ def _new_group_with_tag( | |
| global_rank = default_pg.rank() | ||
| global_world_size = default_pg.size() | ||
|
|
||
|
|
||
| # Default to the same backend as the global process group | ||
| # if the backend is not specified. | ||
| if not backend: | ||
| backend = default_backend | ||
| backend = Backend(backend) | ||
|
|
||
| # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we can reduce the appearance of this defaulting/validation which is sad... Any idea? |
||
| # which may just pass their timeout value (or None) | ||
| if timeout is None: | ||
| timeout = _get_default_timeout(backend) | ||
| _check_valid_timeout(timeout) | ||
|
|
||
| if use_local_synchronization: | ||
| # MPI backend doesn't have have a way for us to perform a partial sync | ||
| if backend == Backend.MPI: | ||
|
|
@@ -4013,7 +4015,7 @@ def _new_group_with_tag( | |
| def new_subgroups( | ||
| group_size=None, | ||
| group=None, | ||
| timeout=default_pg_timeout, | ||
| timeout=None, | ||
| backend=None, | ||
| pg_options=None, | ||
| ): | ||
|
|
@@ -4052,24 +4054,7 @@ def new_subgroups( | |
| the default subgroup size is equal to the number of devices on each machine, | ||
| based on the assumption that each machine has exactly the same | ||
| number of devices. Default is ``None``. | ||
| timeout (timedelta, optional): Timeout for operations executed against | ||
| the process group. Default value equals 30 minutes. | ||
| This is applicable for the ``gloo`` backend. For ``nccl``, this is | ||
| applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` | ||
| or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When | ||
| ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the | ||
| process will block and wait for collectives to complete before | ||
| throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, | ||
| this is the duration after which collectives will be aborted | ||
| asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` | ||
| will provide errors to the user which can be caught and handled, | ||
| but due to its blocking nature, it has a performance overhead. On | ||
| the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little | ||
| performance overhead, but crashes the process on errors. This is | ||
| done since CUDA execution is async and it is no longer safe to | ||
| continue executing user code since failed async NCCL operations | ||
| might result in subsequent CUDA operations running on corrupted | ||
| data. Only one of these two environment variables should be set. | ||
| timeout (timedelta, optional): see `init_process_group` for details and default value. | ||
| backend (str or Backend, optional): The backend to use. Depending on | ||
| build-time configurations, valid values are ``gloo`` and ``nccl``. | ||
| By default uses the same backend as the global group. This field | ||
|
|
@@ -4144,7 +4129,7 @@ def new_subgroups( | |
|
|
||
| def new_subgroups_by_enumeration( | ||
| ranks_per_subgroup_list, | ||
| timeout=default_pg_timeout, | ||
| timeout=None, | ||
| backend=None, | ||
| pg_options=None, | ||
| ): | ||
|
|
@@ -4172,25 +4157,8 @@ def new_subgroups_by_enumeration( | |
| Args: | ||
| ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of | ||
| group members. | ||
| timeout (timedelta, optional): Timeout for operations executed against | ||
| the process group. Default value equals 30 minutes. | ||
| This is applicable for the ``gloo`` backend. For ``nccl``, this is | ||
| applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` | ||
| or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When | ||
| ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the | ||
| process will block and wait for collectives to complete before | ||
| throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, | ||
| this is the duration after which collectives will be aborted | ||
| asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` | ||
| will provide errors to the user which can be caught and handled, | ||
| but due to its blocking nature, it has a performance overhead. On | ||
| the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little | ||
| performance overhead, but crashes the process on errors. This is | ||
| done since CUDA execution is async and it is no longer safe to | ||
| continue executing user code since failed async NCCL operations | ||
| might result in subsequent CUDA operations running on corrupted | ||
| data. Only one of these two environment variables should be set. | ||
| backend (str or Backend, optional): The backend to use. Depending on | ||
| timeout (timedelta, optional): see `init_process_group` for details and default value. | ||
| backend (str or Backend, optional): The backend to use. Depending on | ||
| build-time configurations, valid values are ``gloo`` and ``nccl``. | ||
| By default uses the same backend as the global group. This field | ||
| should be given as a lowercase string (e.g., ``"gloo"``), which can | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicated?