diff --git a/mmengine/testing/_internal/__init__.py b/mmengine/testing/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4528659a8ff0e342012da67cc5e88fe99c1afa4
--- /dev/null
+++ b/mmengine/testing/_internal/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distributed import MultiProcessTestCase
+
+__all__ = ['MultiProcessTestCase']
diff --git a/mmengine/testing/_internal/distributed.py b/mmengine/testing/_internal/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cc6afb6328344f0127cb9da74aae73f219d23ee
--- /dev/null
+++ b/mmengine/testing/_internal/distributed.py
@@ -0,0 +1,355 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) https://github.com/pytorch/pytorch
+# Modified from https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_distributed.py # noqa: E501
+
+import faulthandler
+import logging
+import multiprocessing
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from enum import Enum
+from functools import wraps
+from typing import NamedTuple
+from unittest import TestCase
+
+import torch
+from torch.multiprocessing import active_children
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class TestSkip(NamedTuple):
+    exit_code: int
+    message: str
+
+
+TEST_SKIPS = {
+    'backend_unavailable':
+    TestSkip(10, 'Skipped because distributed backend is not available.'),
+    'no_cuda':
+    TestSkip(11, 'CUDA is not available.'),
+    'multi-gpu-2':
+    TestSkip(12, 'Need at least 2 CUDA device'),
+    'generic':
+    TestSkip(
+        13, 'Test skipped at subprocess level, look at subprocess log for '
+        'skip reason'),
+}
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
+# default `world_size()` returns 2. Let's take `test_rpc_spawn.py` as an
+# example which inherits from this class. Its `Setup()` methods calls into
+# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
+# subprocesses. During the spawn, the main process passes the test name to
+# subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
+
+
+class MultiProcessTestCase(TestCase):
+    MAIN_PROCESS_RANK = -1
+
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
+
+    # do not early terminate for distributed tests.
+    def _should_stop_test_suite(self) -> bool:
+        return False
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @property
+    def timeout(self) -> int:
+        return 500
+
+    def join_or_run(self, fn):
+
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(self, method_name: str = 'runTest') -> None:
+        super().__init__(method_name)
+        fn = getattr(self, method_name)
+        setattr(self, method_name, self.join_or_run(fn))
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.skip_return_code_checks = []  # type: ignore[var-annotated]
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
+        # pid to pipe consisting of error message from process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _current_test_name(self) -> str:
+        # self.id()
+        # e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split('.')[-1]
+
+    def _start_processes(self, proc) -> None:
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name='process ' + str(rank),
+                args=(rank, self._current_test_name(), self.file_name,
+                      child_conn),
+            )
+            process.start()
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        proc = torch.multiprocessing.get_context('spawn').Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        while True:
+            ready_pipes = multiprocessing.connection.wait(
+                [parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+
+                if parent_pipe.closed:
+                    return
+
+                event = parent_pipe.recv()
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode='r+') as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(tmp_file.read())
+
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(cls, rank: int, test_name: str, file_name: str,
+             parent_pipe) -> None:
+        self = cls(test_name)
+
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(
+            duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(f'Process {self.rank} skipping test {test_name} for '
+                        f'following reason: {str(se)}')
+            sys.exit(TEST_SKIPS['generic'].exit_code)
+        except Exception:
+            logger.error(
+                f'Caught exception: \n{traceback.format_exc()} exiting '
+                f'process {self.rank} with exit code: '
+                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError as e:
+                    logger.error(
+                        'Encountered error while trying to get traceback '
+                        f'for process {i}: {e}')
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            f'Pipe closed for process {rank}, cannot retrieve '
+                            'traceback')
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(f'Process {rank} timed out with traceback: '
+                                 f'\n\n{traceback}')
+                else:
+                    logger.error('Could not retrieve traceback for timed out '
+                                 f'process: {rank}')
+            except ConnectionError as e:
+                logger.error(
+                    'Encountered error while trying to get traceback for '
+                    f'process {rank}: {e}')
+
+    def _join_processes(self, fn) -> None:
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for (i, p) in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f'Process {i} terminated with exit code '
+                            f'{p.exitcode}, terminating remaining processes.')
+                        _active_children = active_children()
+                        for ac in _active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+                # All processes have joined cleanly if they all a valid
+                # exitcode
+                if all([p.exitcode is not None for p in self.processes]):
+                    break
+                # Check if we should time out the test. If so, we terminate
+                # each process.
+                elapsed = time.time() - start_time
+                if elapsed > self.timeout:
+                    self._get_timedout_process_traceback()
+                    print(f'Timing out after {self.timeout} seconds and '
+                          'killing subprocesses.')
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+
+            if fn in self.skip_return_code_checks:
+                self._check_no_test_errors(elapsed_time)
+            else:
+                self._check_return_codes(elapsed_time)
+        finally:
+            # Close all pipes
+            for pid, pipe in self.pid_to_pipe.items():
+                pipe.close()
+
+    def _check_no_test_errors(self, elapsed_time) -> None:
+        """Checks that we didn't have any errors thrown in the child
+        processes."""
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    'Process {} timed out after {} seconds'.format(
+                        i, elapsed_time))
+            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
+    def _check_return_codes(self, elapsed_time) -> None:
+        """Checks that the return codes of all spawned processes match, and
+        skips tests if they returned a return code indicating a skipping
+        condition."""
+        first_process = self.processes[0]
+        # first, we check if there are errors in actual processes
+        # (via TEST_ERROR_EXIT CODE), and raise an exception for those.
+        # the reason we do this is to attempt to raise a more helpful error
+        # message than "Process x terminated/timed out"
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p) for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ''
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    'Process {} exited with error code {} and exception:\n{}\n'
+                    .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
+                            error_message))
+
+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then
+        # ensure each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f'Process {i} terminated or timed out after '
+                    f'{elapsed_time} seconds')
+            self.assertEqual(
+                p.exitcode,
+                first_process.exitcode,
+                msg=f'Expect process {i} exit code to match Process 0 exit '
+                f'code of {first_process.exitcode}, but got {p.exitcode}')
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                raise unittest.SkipTest(skip.message)
+        self.assertEqual(
+            first_process.exitcode,
+            0,
+            msg=f'Expected zero exit code but got {first_process.exitcode} '
+            f'for pid: {first_process.pid}')
+
+    @property
+    def is_master(self) -> bool:
+        return self.rank == 0