From 2bf099d33c288a5bb1b7fe165850da2d9d4287a8 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sat, 26 Mar 2022 20:38:25 +0800
Subject: [PATCH] [Feature] Add MultiProcessTestCase (#136)

* [Enhancement] Provide MultiProcessTestCase to test distributed related modules

* remove debugging info

* add timeout property
---
 mmengine/testing/_internal/__init__.py    |   4 +
 mmengine/testing/_internal/distributed.py | 355 ++++++++++++++++++++++
 2 files changed, 359 insertions(+)
 create mode 100644 mmengine/testing/_internal/__init__.py
 create mode 100644 mmengine/testing/_internal/distributed.py

diff --git a/mmengine/testing/_internal/__init__.py b/mmengine/testing/_internal/__init__.py
new file mode 100644
index 00000000..f4528659
--- /dev/null
+++ b/mmengine/testing/_internal/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distributed import MultiProcessTestCase
+
+__all__ = ['MultiProcessTestCase']
diff --git a/mmengine/testing/_internal/distributed.py b/mmengine/testing/_internal/distributed.py
new file mode 100644
index 00000000..2cc6afb6
--- /dev/null
+++ b/mmengine/testing/_internal/distributed.py
@@ -0,0 +1,355 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) https://github.com/pytorch/pytorch
+# Modified from https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_distributed.py  # noqa: E501
+
+import faulthandler
+import logging
+import multiprocessing
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from enum import Enum
+from functools import wraps
+from typing import NamedTuple
+from unittest import TestCase
+
+import torch
+from torch.multiprocessing import active_children
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class TestSkip(NamedTuple):
+    exit_code: int
+    message: str
+
+
+TEST_SKIPS = {
+    'backend_unavailable':
+    TestSkip(14, 'Skipped because distributed backend is not available.'),
+    'no_cuda':
+    TestSkip(11, 'CUDA is not available.'),
+    'multi-gpu-2':
+    TestSkip(12, 'Need at least 2 CUDA devices'),
+    'generic':
+    TestSkip(
+        13, 'Test skipped at subprocess level, look at subprocess log for '
+        'skip reason'),
+}
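+
+# For illustration: a subprocess can signal a skip by exiting with one of the
+# codes above, e.g. `sys.exit(TEST_SKIPS['no_cuda'].exit_code)` when CUDA is
+# required but unavailable. `MultiProcessTestCase._check_return_codes()`
+# turns such an exit code back into `unittest.SkipTest` in the main process,
+# so these codes must stay distinct from
+# `MultiProcessTestCase.TEST_ERROR_EXIT_CODE` defined below.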
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size` processes; by
+# default the `world_size` property returns 2. Take a test class that
+# inherits from MultiProcessTestCase as an example. Its `setUp()` method
+# calls `MultiProcessTestCase._spawn_processes()`, which spawns `world_size`
+# subprocesses. During the spawn, the main process passes the test name,
+# acquired from `self.id()`, to the subprocesses. Each subprocess then uses
+# the provided test name to retrieve the corresponding method from its own
+# test instance and runs it. The main process simply waits for all
+# subprocesses to join.
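+#
+# A minimal usage sketch (the class and test names below are hypothetical and
+# only for illustration):
+#
+#     class ToyDistributedTest(MultiProcessTestCase):
+#
+#         def setUp(self):
+#             super().setUp()
+#             self._spawn_processes()
+#
+#         def test_rank_is_set(self):
+#             # Runs once in each of the `world_size` subprocesses, with
+#             # `self.rank` set to 0 .. world_size - 1. `self.file_name`
+#             # points to a shared temporary file that can be used, e.g.,
+#             # as a file:// init_method when setting up a process group.
+#             assert 0 <= self.rank < self.world_size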
+
+
+class MultiProcessTestCase(TestCase):
+    MAIN_PROCESS_RANK = -1
+
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. Some tests use sys.exit() to simulate failures; in
+    # those cases the exit code cannot be 0, but we still want to ensure that
+    # no other errors occurred.
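+    # A subprocess that hits an unexpected exception signals the failure in
+    # two ways: `run_test()` sends the formatted traceback to the main
+    # process through `parent_pipe` and then exits with this code, and
+    # `_check_return_codes()` re-raises it in the main process as a
+    # RuntimeError that includes the received traceback.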
+    TEST_ERROR_EXIT_CODE = 10
+
+    # Do not terminate the test suite early for distributed tests.
+    def _should_stop_test_suite(self) -> bool:
+        return False
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @property
+    def timeout(self) -> int:
+        return 500
+
+    def join_or_run(self, fn):
+
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # The constructor patches the current instance's test method so that, in
+    # the main process, it joins the subprocesses and, in a spawned
+    # subprocess, it runs the underlying test function.
+    def __init__(self, method_name: str = 'runTest') -> None:
+        super().__init__(method_name)
+        fn = getattr(self, method_name)
+        setattr(self, method_name, self.join_or_run(fn))
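+        # Illustration (hypothetical test name): after construction,
+        # `self.test_foo` is the wrapper returned by `join_or_run()`. In the
+        # main process (rank == MAIN_PROCESS_RANK) it joins the subprocesses;
+        # in a spawned subprocess, where `_run()` overwrites `self.rank`, it
+        # executes the original test body.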
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.skip_return_code_checks = []  # type: ignore[var-annotated]
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
+        # Maps each subprocess pid to the pipe used to receive error
+        # messages from that process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _current_test_name(self) -> str:
+        # self.id() is e.g.
+        # '__main__.TestDistributed.TestAdditive.test_get_rank';
+        # return only the last component, i.e. the test method name.
+        return self.id().split('.')[-1]
+
+    def _start_processes(self, proc) -> None:
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name='process ' + str(rank),
+                args=(rank, self._current_test_name(), self.file_name,
+                      child_conn),
+            )
+            process.start()
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        proc = torch.multiprocessing.get_context('spawn').Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        while True:
+            ready_pipes = multiprocessing.connection.wait(
+                [parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+
+                if parent_pipe.closed:
+                    return
+
+                event = parent_pipe.recv()
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode='r+') as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(tmp_file.read())
+
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(cls, rank: int, test_name: str, file_name: str,
+             parent_pipe) -> None:
+        self = cls(test_name)
+
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(
+            duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(f'Process {self.rank} skipping test {test_name} for '
+                        f'following reason: {str(se)}')
+            sys.exit(TEST_SKIPS['generic'].exit_code)
+        except Exception:
+            logger.error(
+                f'Caught exception: \n{traceback.format_exc()} exiting '
+                f'process {self.rank} with exit code: '
+                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError as e:
+                    logger.error(
+                        'Encountered error while trying to get traceback '
+                        f'for process {i}: {e}')
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            f'Pipe closed for process {rank}, cannot retrieve '
+                            'traceback')
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(f'Process {rank} timed out with traceback: '
+                                 f'\n\n{traceback}')
+                else:
+                    logger.error('Could not retrieve traceback for timed out '
+                                 f'process: {rank}')
+            except ConnectionError as e:
+                logger.error(
+                    'Encountered error while trying to get traceback for '
+                    f'process {rank}: {e}')
+
+    def _join_processes(self, fn) -> None:
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for (i, p) in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f'Process {i} terminated with exit code '
+                            f'{p.exitcode}, terminating remaining processes.')
+                        _active_children = active_children()
+                        for ac in _active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+                # All processes have joined cleanly if they all have a valid
+                # exitcode.
+                if all([p.exitcode is not None for p in self.processes]):
+                    break
+                # Check if we should time out the test. If so, we terminate
+                # each process.
+                elapsed = time.time() - start_time
+                if elapsed > self.timeout:
+                    self._get_timedout_process_traceback()
+                    print(f'Timing out after {self.timeout} seconds and '
+                          'killing subprocesses.')
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+
+            if fn in self.skip_return_code_checks:
+                self._check_no_test_errors(elapsed_time)
+            else:
+                self._check_return_codes(elapsed_time)
+        finally:
+            # Close all pipes
+            for pid, pipe in self.pid_to_pipe.items():
+                pipe.close()
+
+    def _check_no_test_errors(self, elapsed_time) -> None:
+        """Checks that we didn't have any errors thrown in the child
+        processes."""
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    'Process {} timed out after {} seconds'.format(
+                        i, elapsed_time))
+            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
+    def _check_return_codes(self, elapsed_time) -> None:
+        """Checks that the return codes of all spawned processes match, and
+        skips tests if they returned a return code indicating a skipping
+        condition."""
+        first_process = self.processes[0]
+        # First, check whether any subprocess reported an error (via
+        # TEST_ERROR_EXIT_CODE) and raise an exception for those. The reason
+        # we do this is to raise a more helpful error message than
+        # "Process x terminated/timed out".
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p) for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ''
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    'Process {} exited with error code {} and exception:\n{}\n'
+                    .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
+                            error_message))
+
+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then
+        # ensure each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f'Process {i} terminated or timed out after '
+                    f'{elapsed_time} seconds')
+            self.assertEqual(
+                p.exitcode,
+                first_process.exitcode,
+                msg=f'Expected process {i} exit code to match Process 0 exit '
+                f'code of {first_process.exitcode}, but got {p.exitcode}')
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                raise unittest.SkipTest(skip.message)
+        self.assertEqual(
+            first_process.exitcode,
+            0,
+            msg=f'Expected zero exit code but got {first_process.exitcode} '
+            f'for pid: {first_process.pid}')
+
+    @property
+    def is_master(self) -> bool:
+        return self.rank == 0
-- 
GitLab