From 8864bd88d7ba11c984ed00ea4904f721bd544690 Mon Sep 17 00:00:00 2001
From: Austin Welch <austinmw@users.noreply.github.com>
Date: Sat, 8 Oct 2022 07:50:32 -0400
Subject: [PATCH] [Feats]: Add smddp dist backend option (#579)

* Add smddp dist backend option

* [Dev]: Upgrade pre commit hooks (#576)

* Upgrade the versions of pre-commit-hooks

* update zh-cn.yaml

* [Docs] Fix the docstring of model sub-package (#573)

* [Doc]: Update config.md (#562)

* Update config.md

* Update config.md

* [Doc] delete the erroneous comment in docs (#514)

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Zhengfei-0311 <78833899+Zhengfei-0311@users.noreply.github.com>
Co-authored-by: vansin <msnode@163.com>
---
 mmengine/dist/utils.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index e0cf8c11..285c0f37 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -94,6 +94,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
             'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
+    if backend == 'smddp':
+        try:
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                'Please use an Amazon SageMaker DLC to access smdistributed: '
+                'https://github.com/aws/deep-learning-containers/blob/master'
+                '/available_images.md#sagemaker-framework-containers'
+                '-sm-support-only') from e
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     if 'MASTER_PORT' not in os.environ:
@@ -433,6 +442,8 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     elif backend == 'cncl':
         import torch_mlu  # noqa: F401
         return torch.device('mlu', torch.mlu.current_device())
+    elif backend == 'smddp':
+        return torch.device('cuda', torch.cuda.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
-- 
GitLab