Skip to content
Snippets Groups Projects
Unverified Commit 7e1d7af2 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Refactor] Refactor code structure (#395)

* Rename data to structure

* adjust the way to import module

* adjust the way to import module

* rename Structure to Data Structures in docs api

* rename structure to structures

* support using some modules of mmengine without torch

* fix circleci config

* fix circleci config

* fix registry ut

* minor fix

* move init method from model/utils to model/weight_init.py

* move init method from model/utils to model/weight_init.py

* move sync_bn to model

* move functions depending on torch to dl_utils

* format import

* fix logging ut

* add weight init in model/__init__.py

* move get_config and get_model to mmengine/hub

* move log_processor.py to mmengine/runner

* fix ut

* Add TimeCounter in dl_utils/__init__.py
parent 486d8cda
No related branches found
No related tags found
No related merge requests found
Showing
with 97 additions and 39 deletions
...@@ -17,6 +17,34 @@ jobs: ...@@ -17,6 +17,34 @@ jobs:
pip install interrogate pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmengine interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmengine
build_without_torch:
parameters:
# The python version must match available image tags in
# https://circleci.com/developer/images/image/cimg/python
python:
type: string
default: "3.7.4"
docker:
- image: cimg/python:<< parameters.python >>
resource_class: large
steps:
- checkout
- run:
name: Upgrade pip
command: |
python -V
python -m pip install pip --upgrade
python -m pip --version
- run:
name: Install mmengine dependencies
command: python -m pip install -r requirements.txt
- run:
name: Build and install
command: python -m pip install -e .
- run:
name: Run unit tests
command: python -m pytest tests/test_config tests/test_registry tests/test_fileio tests/test_logging tests/test_utils --ignore=tests/test_utils/test_dl_utils
build_cpu: build_cpu:
parameters: parameters:
# The python version must match available image tags in # The python version must match available image tags in
...@@ -101,12 +129,16 @@ workflows: ...@@ -101,12 +129,16 @@ workflows:
unit_tests: unit_tests:
jobs: jobs:
- lint - lint
- build_without_torch:
requires:
- lint
- build_cpu: - build_cpu:
name: build_cpu_th1.8_py3.7 name: build_cpu_th1.8_py3.7
torch: 1.8.0 torch: 1.8.0
torchvision: 0.9.0 torchvision: 0.9.0
requires: requires:
- lint - lint
- build_without_torch
- hold: - hold:
type: approval # <<< This key-value pair will set your workflow to a status of "On Hold" type: approval # <<< This key-value pair will set your workflow to a status of "On Hold"
requires: requires:
......
...@@ -23,13 +23,13 @@ Optimizer ...@@ -23,13 +23,13 @@ Optimizer
.. automodule:: mmengine.optim .. automodule:: mmengine.optim
:members: :members:
Data Data Structures
-------- ----------------
.. automodule:: mmengine.data .. automodule:: mmengine.structures
:members: :members:
Dataset Dataset
-------- ------------
.. automodule:: mmengine.dataset .. automodule:: mmengine.dataset
:members: :members:
......
...@@ -23,9 +23,14 @@ Optimizer ...@@ -23,9 +23,14 @@ Optimizer
.. automodule:: mmengine.optim .. automodule:: mmengine.optim
:members: :members:
Data Data Structures
-------- ----------------
.. automodule:: mmengine.data .. automodule:: mmengine.structures
:members:
Dataset
------------
.. automodule:: mmengine.dataset
:members: :members:
Distributed Distributed
...@@ -42,3 +47,13 @@ Model ...@@ -42,3 +47,13 @@ Model
-------- --------
.. automodule:: mmengine.model .. automodule:: mmengine.model
:members: :members:
Visualization
--------
.. automodule:: mmengine.visualization
:members:
Utils
--------
.. automodule:: mmengine.utils
:members:
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa # flake8: noqa
from .config import * from .config import *
from .data import *
from .dataset import *
from .device import *
from .fileio import * from .fileio import *
from .hooks import *
from .logging import * from .logging import *
from .registry import * from .registry import *
from .runner import *
from .utils import * from .utils import *
from .version import __version__, version_info from .version import __version__, version_info
from .visualization import *
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction from .config import Config, ConfigDict, DictAction
from .get_config_model import get_config, get_model
__all__ = ['Config', 'ConfigDict', 'DictAction', 'get_config', 'get_model'] __all__ = ['Config', 'ConfigDict', 'DictAction']
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .base_dataset import BaseDataset, Compose, force_full_init from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
__all__ = [
'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
'worker_init_fn', 'pseudo_collate'
]
File moved
File moved
...@@ -18,7 +18,7 @@ from .utils import (get_world_size, get_rank, get_backend, get_dist_info, ...@@ -18,7 +18,7 @@ from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device, get_default_group, barrier, get_data_device,
get_comm_device, cast_data_device) get_comm_device, cast_data_device)
from mmengine.utils.version_utils import digit_version from mmengine.utils.version_utils import digit_version
from mmengine.utils.parrots_wrapper import TORCH_VERSION from mmengine.utils.dl_utils import TORCH_VERSION
def _get_reduce_op(name: str) -> torch_dist.ReduceOp: def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.registry import EVALUATOR, METRICS from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric from .metric import BaseMetric
......
...@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Union ...@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Union
from torch import Tensor from torch import Tensor
from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results, from mmengine.dist import (broadcast_object_list, collect_results,
is_main_process) is_main_process)
from mmengine.fileio import dump from mmengine.fileio import dump
from mmengine.logging import print_log from mmengine.logging import print_log
from mmengine.registry import METRICS from mmengine.registry import METRICS
from mmengine.structures import BaseDataElement
class BaseMetric(metaclass=ABCMeta): class BaseMetric(metaclass=ABCMeta):
......
...@@ -3,8 +3,8 @@ from typing import Optional, Sequence, Union ...@@ -3,8 +3,8 @@ from typing import Optional, Sequence, Union
import torch import torch
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement from mmengine.structures import BaseDataElement
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import time import time
from typing import Optional, Sequence, Union from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]
......
...@@ -4,10 +4,10 @@ import os.path as osp ...@@ -4,10 +4,10 @@ import os.path as osp
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient, dump from mmengine.fileio import FileClient, dump
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_tuple_of, scandir from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]
......
...@@ -5,10 +5,10 @@ from typing import Optional, Sequence, Tuple ...@@ -5,10 +5,10 @@ from typing import Optional, Sequence, Tuple
import cv2 import cv2
import numpy as np import numpy as np
from mmengine.data import BaseDataElement
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs from mmengine.structures import BaseDataElement
from mmengine.utils.dl_utils import tensor2imgs
# TODO: Due to interface changes, the current class # TODO: Due to interface changes, the current class
......
# Copyright (c) OpenMMLab. All rights reserved.
from .hub import get_config, get_model
__all__ = ['get_config', 'get_model']
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
import importlib import importlib
import os.path as osp import os.path as osp
import torch.nn as nn from mmengine.config import Config
from mmengine.config.utils import (_get_cfg_metainfo,
_get_external_cfg_base_path,
_get_package_and_cfg_path)
from mmengine.registry import MODELS, DefaultScope from mmengine.registry import MODELS, DefaultScope
from mmengine.runner import load_checkpoint
from mmengine.utils import check_install_package, get_installed_path from mmengine.utils import check_install_package, get_installed_path
from .config import Config
from .utils import (_get_cfg_metainfo, _get_external_cfg_base_path,
_get_package_and_cfg_path)
def get_config(cfg_path: str, pretrained: bool = False) -> Config: def get_config(cfg_path: str, pretrained: bool = False) -> Config:
...@@ -56,7 +56,7 @@ def get_config(cfg_path: str, pretrained: bool = False) -> Config: ...@@ -56,7 +56,7 @@ def get_config(cfg_path: str, pretrained: bool = False) -> Config:
return cfg return cfg
def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module: def get_model(cfg_path: str, pretrained: bool = False, **kwargs):
"""Get built model from external package. """Get built model from external package.
Args: Args:
...@@ -68,7 +68,6 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module: ...@@ -68,7 +68,6 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
Returns: Returns:
nn.Module: Built model. nn.Module: Built model.
""" """
import mmengine.runner
package = cfg_path.split('::')[0] package = cfg_path.split('::')[0]
with DefaultScope.overwrite_default_scope(package): # type: ignore with DefaultScope.overwrite_default_scope(package): # type: ignore
cfg = get_config(cfg_path, pretrained) cfg = get_config(cfg_path, pretrained)
...@@ -76,5 +75,5 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module: ...@@ -76,5 +75,5 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
models_module.register_all_modules() # type: ignore models_module.register_all_modules() # type: ignore
model = MODELS.build(cfg.model, default_args=kwargs) model = MODELS.build(cfg.model, default_args=kwargs)
if pretrained: if pretrained:
mmengine.runner.load_checkpoint(model, cfg.model_path) load_checkpoint(model, cfg.model_path)
return model return model
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer from .history_buffer import HistoryBuffer
from .log_processor import LogProcessor
from .logger import MMLogger, print_log from .logger import MMLogger, print_log
from .message_hub import MessageHub from .message_hub import MessageHub
__all__ = [ __all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
]
...@@ -7,7 +7,6 @@ from typing import Optional, Union ...@@ -7,7 +7,6 @@ from typing import Optional, Union
from termcolor import colored from termcolor import colored
from mmengine.dist import get_rank
from mmengine.utils import ManagerMixin from mmengine.utils import ManagerMixin
from mmengine.utils.manager import _accquire_lock, _release_lock from mmengine.utils.manager import _accquire_lock, _release_lock
...@@ -152,7 +151,8 @@ class MMLogger(Logger, ManagerMixin): ...@@ -152,7 +151,8 @@ class MMLogger(Logger, ManagerMixin):
Logger.__init__(self, logger_name) Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name) ManagerMixin.__init__(self, name)
# Get rank in DDP mode. # Get rank in DDP mode.
rank = get_rank()
rank = _get_rank()
# Config stream_handler. If `rank != 0`. stream_handler can only # Config stream_handler. If `rank != 0`. stream_handler can only
# export ERROR logs. # export ERROR logs.
...@@ -289,3 +289,14 @@ def print_log(msg, ...@@ -289,3 +289,14 @@ def print_log(msg,
raise TypeError( raise TypeError(
'`logger` should be either a logging.Logger object, str, ' '`logger` should be either a logging.Logger object, str, '
f'"silent", "current" or None, but got {type(logger)}') f'"silent", "current" or None, but got {type(logger)}')
def _get_rank():
"""Support using logging module without torch."""
try:
# requires torch
from mmengine.dist import get_rank
except ImportError:
return 0
else:
return get_rank()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment