diff --git a/mmengine/config/__init__.py b/mmengine/config/__init__.py
index abf09ab2f2ddefac91f3bc9fcc35035c7281297d..bdf470671a743ede1a86909b709b79b7bd3e467e 100644
--- a/mmengine/config/__init__.py
+++ b/mmengine/config/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .config import Config, ConfigDict, DictAction
+from .get_config_model import get_config, get_model
 
-__all__ = ['Config', 'ConfigDict', 'DictAction']
+__all__ = ['Config', 'ConfigDict', 'DictAction', 'get_config', 'get_model']
diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 0e06580261a3a61a7e934b5390d46a7c8848fc8b..18c7694336125eeccff9842d27ab1c207d351467 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -12,14 +12,16 @@ import warnings
 from argparse import Action, ArgumentParser, Namespace
 from collections import abc
 from pathlib import Path
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence, Tuple, Union
 
 from addict import Dict
 from yapf.yapflib.yapf_api import FormatCode
 
 from mmengine.fileio import dump, load
-from mmengine.utils import check_file_exist, import_modules_from_strings
-from .utils import RemoveAssignFromAST
+from mmengine.utils import (check_file_exist, check_install_package,
+                            get_installed_path, import_modules_from_strings)
+from .utils import (RemoveAssignFromAST, _get_external_cfg_base_path,
+                    _get_external_cfg_path, _get_package_and_cfg_path)
 
 BASE_KEY = '_base_'
 DELETE_KEY = '_delete_'
@@ -393,16 +395,26 @@ class Config:
             # Handle base files
             base_cfg_dict = ConfigDict()
             cfg_text_list = list()
-            for base_cfg_path in Config._parse_base_files(
-                    temp_config_file.name):
-                cfg_dir = osp.dirname(filename)
-                _cfg_dict, _cfg_text = Config._file2dict(
-                    osp.join(cfg_dir, base_cfg_path))
+            for base_cfg_path in Config._get_base_files(temp_config_file.name):
+                base_cfg_path, scope = Config._get_cfg_path(
+                    base_cfg_path, filename)
+                _cfg_dict, _cfg_text = Config._file2dict(base_cfg_path)
                 cfg_text_list.append(_cfg_text)
                 duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys()
                 if len(duplicate_keys) > 0:
                     raise KeyError('Duplicate key is not allowed among bases. '
                                    f'Duplicate keys: {duplicate_keys}')
+
+                # _dict_to_config_dict will do the following things:
+                # 1. Recursively converts ``dict`` to :obj:`ConfigDict`.
+                # 2. Set `_scope_` for the outer dict variable for the base
+                # config.
+                # 3. Set `scope` attribute for each base variable. Different
+                # from `_scope_`, `scope` is not a key of base dict,
+                # `scope` attribute will be parsed to key `_scope_` by
+                # function `_parse_scope` only if the base variable is
+                # accessed by the current config.
+                _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope)
                 base_cfg_dict.update(_cfg_dict)
 
             if filename.endswith('.py'):
@@ -428,6 +440,11 @@ class Config:
                     cfg_dict.pop(key)
             temp_config_file.close()
 
+            # If the current config accesses a base variable of base
+            # configs, The ``scope`` attribute of corresponding variable
+            # will be converted to the `_scope_`.
+            Config._parse_scope(cfg_dict)
+
         # check deprecation information
         if DEPRECATION_KEY in cfg_dict:
             deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
@@ -464,19 +481,76 @@ class Config:
         return cfg_dict, cfg_text
 
     @staticmethod
-    def _parse_base_files(file_path: str) -> List[str]:
-        """Get paths of all base config files.
+    def _dict_to_config_dict(cfg: dict,
+                             scope: Optional[str] = None,
+                             has_scope=True):
+        """Recursively converts ``dict`` to :obj:`ConfigDict`.
 
         Args:
-            file_path (str): Path of config.
+            cfg (dict): Config dict.
+            scope (str, optional): Scope of instance.
+            has_scope (bool): Whether to add `_scope_` key to config dict.
 
         Returns:
-            List[str]: paths of all base files .
+            ConfigDict: Converted dict.
+        """
+        # Only the outer dict with key `type` should have the key `_scope_`.
+        if isinstance(cfg, dict):
+            if has_scope and 'type' in cfg:
+                has_scope = False
+                if scope is not None and cfg.get('_scope_', None) is None:
+                    cfg._scope_ = scope  # type: ignore
+            cfg = ConfigDict(cfg)
+            dict.__setattr__(cfg, 'scope', scope)
+            for key, value in cfg.items():
+                cfg[key] = Config._dict_to_config_dict(
+                    value, scope=scope, has_scope=has_scope)
+        elif isinstance(cfg, tuple):
+            cfg = tuple(
+                Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope)
+                for _cfg in cfg)
+        elif isinstance(cfg, list):
+            cfg = [
+                Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope)
+                for _cfg in cfg
+            ]
+        return cfg
+
+    @staticmethod
+    def _parse_scope(cfg: dict) -> None:
+        """Adds ``_scope_`` to :obj:`ConfigDict` instance, which means a base
+        variable.
+
+        If the config dict already has the scope, scope will not be
+        overwritten.
+
+        Args:
+            cfg (dict): Config needs to be parsed with scope.
         """
-        file_format = file_path.partition('.')[-1]
+        if isinstance(cfg, ConfigDict):
+            cfg._scope_ = cfg.scope
+        elif isinstance(cfg, (tuple, list)):
+            [Config._parse_scope(value) for value in cfg]
+        else:
+            return
+
+    @staticmethod
+    def _get_base_files(filename: str) -> list:
+        """Get the base config file.
+
+        Args:
+            filename (str): The config file.
+
+        Raises:
+            TypeError: Name of config file.
+
+        Returns:
+            list: A list of base config
+        """
+        file_format = filename.partition('.')[-1]
         if file_format == 'py':
-            Config._validate_py_syntax(file_path)
-            with open(file_path) as f:
+            Config._validate_py_syntax(filename)
+            with open(filename) as f:
                 codes = ast.parse(f.read()).body
 
                 def is_base_line(c):
@@ -492,7 +566,8 @@ class Config:
                 else:
                     base_files = []
         elif file_format in ('yml', 'yaml', 'json'):
-            cfg_dict = load(file_path)
+            import mmengine
+            cfg_dict = mmengine.load(filename)
             base_files = cfg_dict.get(BASE_KEY, [])
         else:
             raise TypeError('The config type should be py, json, yaml or '
@@ -501,6 +576,43 @@ class Config:
                                               list) else [base_files]
         return base_files
 
+    @staticmethod
+    def _get_cfg_path(cfg_path: str,
+                      filename: str) -> Tuple[str, Optional[str]]:
+        """Get the config path from the current or external package.
+
+        Args:
+            cfg_path (str): Relative path of config.
+            filename (str): The config file being parsed.
+
+        Returns:
+            Tuple[str, str or None]: Path and scope of config. If the config
+            is not an external config, the scope will be `None`.
+        """
+        if '::' in cfg_path:
+            # `cfg_path` startswith '::' means an external config path.
+            # Get package name and relative config path.
+            scope = cfg_path.partition('::')[0]
+            package, cfg_path = _get_package_and_cfg_path(cfg_path)
+            # Get installed package path.
+            check_install_package(package)
+            package_path = get_installed_path(package)
+            try:
+                # Get config path from meta file.
+                cfg_path = _get_external_cfg_path(package_path, cfg_path)
+            except ValueError:
+                # Since base config does not have a metafile, it should be
+                # concatenated with package path and relative config path.
+                cfg_path = _get_external_cfg_base_path(package_path, cfg_path)
+            except FileNotFoundError as e:
+                raise e
+            return cfg_path, scope
+        else:
+            # Get local config path.
+            cfg_dir = osp.dirname(filename)
+            cfg_path = osp.join(cfg_dir, cfg_path)
+            return cfg_path, None
+
     @staticmethod
     def _merge_a_into_b(a: dict,
                         b: dict,
diff --git a/mmengine/config/get_config_model.py b/mmengine/config/get_config_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..14b5936a5b6280c901ca6ba76a54ba91e883d1ce
--- /dev/null
+++ b/mmengine/config/get_config_model.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os.path as osp
+
+import torch.nn as nn
+
+from mmengine.registry import MODELS, DefaultScope
+from mmengine.utils import check_install_package, get_installed_path
+from .config import Config
+from .utils import (_get_cfg_metainfo, _get_external_cfg_base_path,
+                    _get_package_and_cfg_path)
+
+
+def get_config(cfg_path: str, pretrained: bool = False) -> Config:
+    """Get config from external package.
+
+    Args:
+        cfg_path (str): External relative config path.
+        pretrained (bool): Whether to save pretrained model path. If
+            ``pretrained==True``, the url of pretrained model can be accessed
+            by ``cfg.model_path``. Defaults to False.
+
+    Examples:
+        >>> cfg = get_config('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco',
+        >>>                  pretrained=True)
+        >>> # Equivalent to
+        >>> Config.fromfile('/path/to/faster_rcnn_r50_fpn_1x_coco.py')
+        >>> cfg.model_path
+        https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
+
+    Returns:
+        Config: A `Config` parsed from external package.
+    """  # noqa E301
+    # Get package name and relative config path.
+    package, cfg_path = _get_package_and_cfg_path(cfg_path)
+    # Check package is installed.
+    check_install_package(package)
+    package_path = get_installed_path(package)
+    try:
+        # Use `cfg_path` to search target config file.
+        cfg_meta = _get_cfg_metainfo(package_path, cfg_path)
+        cfg_path = osp.join(package_path, '.mim', cfg_meta['Config'])
+        cfg = Config.fromfile(cfg_path)
+        if pretrained:
+            assert 'Weights' in cfg_meta, ('Cannot find `Weights` in cfg_file'
+                                           '.metafile.yml, please check the'
+                                           'metafile')
+            cfg.model_path = cfg_meta['Weights']
+    except ValueError:
+        # Since the base config does not contain a metafile, the absolute
+        # config is `osp.join(package_path, cfg_path_prefix, cfg_name)`
+        cfg_path = _get_external_cfg_base_path(package_path, cfg_path)
+        cfg = Config.fromfile(cfg_path)
+    except Exception as e:
+        raise e
+    return cfg
+
+
+def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
+    """Get built model from external package.
+
+    Args:
+        cfg_path (str): External relative config path with prefix
+            'package::' and without suffix.
+        pretrained (bool): Whether to load pretrained model. Defaults to False.
+        kwargs (dict): Default arguments to build model.
+
+    Returns:
+        nn.Module: Built model.
+    """
+    import mmengine.runner
+    package = cfg_path.split('::')[0]
+    with DefaultScope.overwrite_default_scope(package):  # type: ignore
+        cfg = get_config(cfg_path, pretrained)
+        models_module = importlib.import_module(f'{package}.utils')
+        models_module.register_all_modules()  # type: ignore
+        model = MODELS.build(cfg.model, default_args=kwargs)
+        if pretrained:
+            mmengine.runner.load_checkpoint(model, cfg.model_path)
+        return model
diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py
index 54c32691f71a2cca051eaa8b65f2cf6aab143657..7eec2c653bac757e1e7dc19fd5c3871070692b2e 100644
--- a/mmengine/config/utils.py
+++ b/mmengine/config/utils.py
@@ -1,5 +1,122 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import ast
+import os.path as osp
+import re
+import warnings
+from typing import Tuple
+
+from mmengine.fileio import load
+from mmengine.utils import check_file_exist
+
+PKG2PROJECT = {
+    'mmcls': 'mmcls',
+    'mmdet': 'mmdet',
+    'mmdet3d': 'mmdet3d',
+    'mmseg': 'mmsegmentation',
+    'mmaction2': 'mmaction2',
+    'mmtrack': 'mmtrack',
+    'mmpose': 'mmpose',
+    'mmedit': 'mmedit',
+    'mmocr': 'mmocr',
+    'mmgen': 'mmgen',
+    'mmfewshot': 'mmfewshot',
+    'mmrazor': 'mmrazor',
+    'mmflow': 'mmflow',
+    'mmhuman3d': 'mmhuman3d',
+    'mmrotate': 'mmrotate',
+    'mmselfsup': 'mmselfsup',
+}
+
+
+def _get_cfg_metainfo(package_path: str, cfg_path: str) -> dict:
+    """Get target meta information from all 'metafile.yml' defined in `mode-
+    index.yml` of external package.
+
+    Args:
+        package_path (str): Path of external package.
+        cfg_path (str): Name of experiment config.
+
+    Returns:
+        dict: Meta information of target experiment.
+    """
+    meta_index_path = osp.join(package_path, '.mim', 'model-index.yml')
+    meta_index = load(meta_index_path)
+    cfg_dict = dict()
+    for meta_path in meta_index['Import']:
+        meta_path = osp.join(package_path, '.mim', meta_path)
+        cfg_meta = load(meta_path)
+        for model_cfg in cfg_meta['Models']:
+            if 'Config' not in model_cfg:
+                warnings.warn(f'There is not `Config` define in {model_cfg}')
+                continue
+            cfg_name = model_cfg['Config'].partition('/')[-1]
+            # Some config could have multiple weights, we only pick the
+            # first one.
+            if cfg_name in cfg_dict:
+                continue
+            cfg_dict[cfg_name] = model_cfg
+    if cfg_path not in cfg_dict:
+        raise ValueError(f'Expected configs: {cfg_dict.keys()}, but got '
+                         f'{cfg_path}')
+    return cfg_dict[cfg_path]
+
+
+def _get_external_cfg_path(package_path: str, cfg_file: str) -> str:
+    """Get config path of external package.
+
+    Args:
+        package_path (str): Path of external package.
+        cfg_file (str): Name of experiment config.
+
+    Returns:
+        str: Absolute config path from external package.
+    """
+    cfg_file = cfg_file.split('.')[0]
+    model_cfg = _get_cfg_metainfo(package_path, cfg_file)
+    cfg_path = osp.join(package_path, model_cfg['Config'])
+    check_file_exist(cfg_path)
+    return cfg_path
+
+
+def _get_external_cfg_base_path(package_path: str, cfg_name: str) -> str:
+    """Get base config path of external package.
+
+    Args:
+        package_path (str): Path of external package.
+        cfg_name (str): External relative config path with 'package::'.
+
+    Returns:
+        str: Absolute config path from external package.
+    """
+    cfg_path = osp.join(package_path, '.mim', 'configs', cfg_name)
+    check_file_exist(cfg_path)
+    return cfg_path
+
+
+def _get_package_and_cfg_path(cfg_path: str) -> Tuple[str, str]:
+    """Get package name and relative config path.
+
+    Args:
+        cfg_path (str): External relative config path with 'package::'.
+
+    Returns:
+        Tuple[str, str]: Package name and config path.
+    """
+    if re.match(r'\w*::\w*/\w*', cfg_path) is None:
+        raise ValueError(
+            '`_get_package_and_cfg_path` is used for get external package, '
+            'please specify the package name and relative config path, just '
+            'like `mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py`')
+    package_cfg = cfg_path.split('::')
+    if len(package_cfg) > 2:
+        raise ValueError('`::` should only be used to separate package and '
+                         'config name, but found multiple `::` in '
+                         f'{cfg_path}')
+    package, cfg_path = package_cfg
+    assert package in PKG2PROJECT, 'mmengine does not support to load ' \
+                                   f'{package} config.'
+    package = PKG2PROJECT[package]
+    return package, cfg_path
 
 
 class RemoveAssignFromAST(ast.NodeTransformer):
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index 351aabeea6aa8fa2eb11eee5c262e8f6d8790c1d..5fc4e56f95947d45a1bb022098fdefdf270b412a 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -4,8 +4,10 @@ import logging
 import sys
 from collections.abc import Callable
 from contextlib import contextmanager
+from importlib import import_module
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
 
+from ..config.utils import PKG2PROJECT
 from ..utils import is_seq_of
 from .default_scope import DefaultScope
 
@@ -239,6 +241,13 @@ class Registry:
             # Get registry by scope
             if default_scope is not None:
                 scope_name = default_scope.scope_name
+                if scope_name in PKG2PROJECT:
+                    try:
+                        module = import_module(
+                            f'{PKG2PROJECT[scope_name]}.utils')
+                        module.register_all_modules()  # type: ignore
+                    except ImportError as e:
+                        raise e
                 root = self._get_root_registry()
                 registry = root._search_child(scope_name)
                 if registry is None:
diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py
index df0803786bcb142228d0b5d641024f7a54d5fb60..51ad2e413169d21973d0ce931435d775362e1af3 100644
--- a/mmengine/utils/__init__.py
+++ b/mmengine/utils/__init__.py
@@ -8,6 +8,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
                    requires_executable, requires_package, slice_list,
                    to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
                    tuple_cast)
+from .package_utils import (call_command, check_install_package,
+                            get_installed_path, is_installed)
 from .parrots_wrapper import TORCH_VERSION
 from .path import (check_file_exist, fopen, is_abs, is_filepath,
                    mkdir_or_exist, scandir, symlink)
@@ -28,5 +30,6 @@ __all__ = [
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
     'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
-    'is_abs', 'revert_sync_batchnorm'
+    'is_abs', 'is_installed', 'call_command', 'get_installed_path',
+    'check_install_package', 'is_abs', 'revert_sync_batchnorm'
 ]
diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfd41eb179eecbe9b248624c80ef547377bfed2a
--- /dev/null
+++ b/mmengine/utils/package_utils.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os.path as osp
+import subprocess
+
+import pkg_resources
+from pkg_resources import get_distribution
+
+
+def is_installed(package: str) -> bool:
+    """Check package whether installed.
+
+    Args:
+        package (str): Name of package to be checked.
+    """
+    # refresh the pkg_resources
+    # more datails at https://github.com/pypa/setuptools/issues/373
+    importlib.reload(pkg_resources)
+    try:
+        get_distribution(package)
+        return True
+    except pkg_resources.DistributionNotFound:
+        return False
+
+
+def get_installed_path(package: str) -> str:
+    """Get installed path of package.
+
+    Args:
+        package (str): Name of package.
+
+    Example:
+        >>> get_installed_path('mmcls')
+        >>> '.../lib/python3.7/site-packages/mmcls'
+    """
+    # if the package name is not the same as module name, module name should be
+    # inferred. For example, mmcv-full is the package name, but mmcv is module
+    # name. If we want to get the installed path of mmcv-full, we should concat
+    # the pkg.location and module name
+    pkg = get_distribution(package)
+    possible_path = osp.join(pkg.location, package)
+    if osp.exists(possible_path):
+        return possible_path
+    else:
+        return osp.join(pkg.location, package2module(package))
+
+
+def package2module(package: str):
+    """Infer module name from package.
+
+    Args:
+        package (str): Package to infer module name.
+    """
+    pkg = get_distribution(package)
+    if pkg.has_metadata('top_level.txt'):
+        module_name = pkg.get_metadata('top_level.txt').split('\n')[0]
+        return module_name
+    else:
+        raise ValueError(f'can not infer the module name of {package}')
+
+
+def call_command(cmd: list) -> None:
+    try:
+        subprocess.check_call(cmd)
+    except Exception as e:
+        raise e  # type: ignore
+
+
+def check_install_package(package: str):
+    if not is_installed(package):
+        call_command(['python', '-m', 'pip', 'install', package])
diff --git a/tests/data/config/py_config/test_get_external_cfg.py b/tests/data/config/py_config/test_get_external_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a72f30feb59988874160ed2529290d49416c73
--- /dev/null
+++ b/tests/data/config/py_config/test_get_external_cfg.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = [
+    'mmdet::_base_/models/faster_rcnn_r50_fpn.py',
+    'mmdet::_base_/datasets/coco_detection.py',
+    'mmdet::_base_/schedules/schedule_1x.py',
+    'mmdet::_base_/default_runtime.py'
+]
diff --git a/tests/data/config/py_config/test_get_external_cfg2.py b/tests/data/config/py_config/test_get_external_cfg2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2d60ceeb9888b3a3979dfb0ddff0473a75482eb
--- /dev/null
+++ b/tests/data/config/py_config/test_get_external_cfg2.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py
new file mode 100644
index 0000000000000000000000000000000000000000..7313b8038c1ab2b32b617e2d9d25f0d9f5cd2a54
--- /dev/null
+++ b/tests/data/config/py_config/test_get_external_cfg3.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = [
+    'mmdet::_base_/models/faster_rcnn_r50_fpn.py',
+    'mmdet::_base_/datasets/coco_detection.py',
+    'mmdet::_base_/schedules/schedule_1x.py',
+    'mmdet::_base_/default_runtime.py',
+    './test_get_external_cfg_base.py'
+]
+
+custom_hooks = [dict(type='mmdet.DetVisualizationHook')]
+
+model = dict(
+    roi_head=dict(
+        bbox_head=dict(
+            loss_cls=dict(_delete_=True, type='test.ToyLoss')
+        )
+    )
+)
diff --git a/tests/data/config/py_config/test_get_external_cfg_base.py b/tests/data/config/py_config/test_get_external_cfg_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d680ef0a6b8d9adf8edbddd1ef54cdd0dc0370c4
--- /dev/null
+++ b/tests/data/config/py_config/test_get_external_cfg_base.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+toy_model = dict(type='ToyModel')
diff --git a/tests/test_config/test_collect_meta.py b/tests/test_config/test_collect_meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7b5e73a2788908a01e5b929245914f399739699
--- /dev/null
+++ b/tests/test_config/test_collect_meta.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path
+
+import pytest
+
+from mmengine.config.utils import (_get_external_cfg_base_path,
+                                   _get_package_and_cfg_path)
+
+
+def test_get_external_cfg_base_path(tmp_path):
+    package_path = tmp_path
+    rel_cfg_path = 'cfg_dir/cfg_file'
+    with pytest.raises(FileNotFoundError):
+        _get_external_cfg_base_path(str(package_path), rel_cfg_path)
+    cfg_dir = tmp_path / '.mim' / 'configs' / 'cfg_dir'
+    cfg_dir.mkdir(parents=True, exist_ok=True)
+    f = open(cfg_dir / 'cfg_file', 'w')
+    f.close()
+    cfg_path = _get_external_cfg_base_path(str(package_path), rel_cfg_path)
+    assert cfg_path == f'{os.path.join(str(cfg_dir), "cfg_file")}'
+
+
+def test_get_external_cfg_path():
+    external_cfg_path = 'mmdet::path/cfg'
+    package, rel_cfg_path = _get_package_and_cfg_path(external_cfg_path)
+    assert package == 'mmdet'
+    assert rel_cfg_path == 'path/cfg'
+    # external config must contain `::`.
+    external_cfg_path = 'path/cfg'
+    with pytest.raises(ValueError):
+        _get_package_and_cfg_path(external_cfg_path)
+    # Use `:::` as operator will raise an error.
+    external_cfg_path = 'mmdet:::path/cfg'
+    with pytest.raises(ValueError):
+        _get_package_and_cfg_path(external_cfg_path)
+    # Use `:` as operator will raise an error.
+    external_cfg_path = 'mmdet:path/cfg'
+    with pytest.raises(ValueError):
+        _get_package_and_cfg_path(external_cfg_path)
+    # Too much `::`
+    external_cfg_path = 'mmdet::path/cfg::error'
+    with pytest.raises(ValueError):
+        _get_package_and_cfg_path(external_cfg_path)
diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index 671bdc1283e6d5fb22abc9db851ef343e6417129..8c33856cc163a5569c933eccf915438b36c04b7b 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -12,6 +12,8 @@ import pytest
 
 from mmengine import Config, ConfigDict, DictAction
 from mmengine.fileio import dump, load
+from mmengine.registry import MODELS, DefaultScope, Registry
+from mmengine.utils import is_installed
 
 
 class TestConfig:
@@ -186,6 +188,21 @@ class TestConfig:
         #  overwritten by int
         sys.argv.extend(tmp)
 
+    def test_dict_to_config_dict(self):
+        cfg_dict = dict(
+            a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))])
+        cfg_dict = Config._dict_to_config_dict(cfg_dict)
+        assert isinstance(cfg_dict, ConfigDict)
+        assert isinstance(cfg_dict.a, int)
+        assert isinstance(cfg_dict.b, ConfigDict)
+        assert isinstance(cfg_dict.b.c, ConfigDict)
+        assert isinstance(cfg_dict.d, list)
+        assert isinstance(cfg_dict.d[0], ConfigDict)
+        assert isinstance(cfg_dict.d[0].e, ConfigDict)
+        assert isinstance(cfg_dict.d[0].e.f, tuple)
+        assert isinstance(cfg_dict.d[0].e.f[0], ConfigDict)
+        assert isinstance(cfg_dict.d[0].e.f[1], list)
+
     def test_dump(self, tmp_path):
         file_path = 'config/py_config/test_merge_from_multiple_bases.py'
         cfg_file = osp.join(self.data_path, file_path)
@@ -266,7 +283,7 @@ class TestConfig:
             print(cfg, file=f)
         with open(tmp_txt) as f:
             assert f.read().strip() == f'Config (path: {cfg.filename}): ' \
-                               f'{cfg._cfg_dict.__repr__()}'
+                                       f'{cfg._cfg_dict.__repr__()}'
 
     def test_dict_action(self):
         parser = argparse.ArgumentParser(description='Train a detector')
@@ -418,6 +435,26 @@ class TestConfig:
         self._merge_recursive_bases()
         self._deprecation()
 
+    def test_get_cfg_path(self):
+        filename = 'py_config/simple_config.py'
+        filename = osp.join(self.data_path, 'config', filename)
+        cfg_name = './base.py'
+        cfg_path, scope = Config._get_cfg_path(cfg_name, filename)
+        assert scope is None
+        osp.isfile(cfg_path)
+
+        # Test scope equal to package name.
+        cfg_name = 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
+        cfg_path, scope = Config._get_cfg_path(cfg_name, filename)
+        assert scope == 'mmdet'
+        osp.isfile(cfg_path)
+
+        # Test scope does not equal to package name.
+        cfg_name = 'mmcls::cspnet/cspresnet50_8xb32_in1k.py'
+        cfg_path, scope = Config._get_cfg_path(cfg_name, filename)
+        assert scope == 'mmcls'
+        osp.isfile(cfg_path)
+
     def _simple_load(self):
         # test load simple config
         for file_format in ['py', 'json', 'yaml']:
@@ -750,3 +787,57 @@ class TestConfig:
         assert new_cfg._cfg_dict is cfg._cfg_dict
         assert new_cfg._filename == cfg._filename
         assert new_cfg._text == cfg._text
+
+    def test_get_external_cfg(self):
+        ext_cfg_path = osp.join(self.data_path,
+                                'config/py_config/test_get_external_cfg.py')
+        ext_cfg = Config.fromfile(ext_cfg_path)
+        assert ext_cfg._cfg_dict.model.neck == dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            num_outs=5,
+        )
+        assert '_scope_' in ext_cfg._cfg_dict.model
+
+    @pytest.mark.skipif(
+        not is_installed('mmdet'), reason='mmdet should be installed')
+    def test_build_external_package(self):
+        # Test load base config.
+        ext_cfg_path = osp.join(self.data_path,
+                                'config/py_config/test_get_external_cfg.py')
+        ext_cfg = Config.fromfile(ext_cfg_path)
+
+        LOCAL_MODELS = Registry('local_model', parent=MODELS, scope='test')
+        LOCAL_MODELS.build(ext_cfg.model)
+
+        # Test load non-base config
+        ext_cfg_path = osp.join(self.data_path,
+                                'config/py_config/test_get_external_cfg2.py')
+        ext_cfg = Config.fromfile(ext_cfg_path)
+        LOCAL_MODELS.build(ext_cfg.model)
+
+        # Test override base variable.
+        ext_cfg_path = osp.join(self.data_path,
+                                'config/py_config/test_get_external_cfg3.py')
+        ext_cfg = Config.fromfile(ext_cfg_path)
+
+        @LOCAL_MODELS.register_module()
+        class ToyLoss:
+            pass
+
+        @LOCAL_MODELS.register_module()
+        class ToyModel:
+            pass
+
+        DefaultScope.get_instance('test1', scope_name='test')
+        assert ext_cfg.model._scope_ == 'mmdet'
+        model = LOCAL_MODELS.build(ext_cfg.model)
+
+        # Local base config should not have scope.
+        assert '_scope_' not in ext_cfg.toy_model
+        toy_model = LOCAL_MODELS.build(ext_cfg.toy_model)
+        assert isinstance(toy_model, ToyModel)
+        assert model.backbone.style == 'pytorch'
+        assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss)
+        DefaultScope._instance_dict.pop('test1')
diff --git a/tests/test_config/test_get_config_model.py b/tests/test_config/test_get_config_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..812ab8242a1a2ce4e5925e297546a887b137e922
--- /dev/null
+++ b/tests/test_config/test_get_config_model.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import pytest
+
+from mmengine import Config, DefaultScope, get_config, get_model
+from mmengine.utils import get_installed_path, is_installed
+
+data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/')
+
+
+# mmdet has a more typical config structure, while mmpose has a complex
+# config structure
+@pytest.mark.skipif(
+    not (is_installed('mmdet') and is_installed('mmpose')),
+    reason='mmdet and mmpose should be installed')
+def test_get_config():
+    # Test load base config.
+    base_cfg = get_config('mmdet::_base_/models/faster_rcnn_r50_fpn.py')
+    package_path = get_installed_path('mmdet')
+    test_base_cfg = Config.fromfile(
+        osp.join(package_path, '.mim',
+                 'configs/_base_/models/faster_rcnn_r50_fpn.py'))
+    assert test_base_cfg._cfg_dict == base_cfg._cfg_dict
+
+    # Test load faster_rcnn config
+    cfg = get_config('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
+    test_cfg = Config.fromfile(
+        osp.join(package_path, '.mim',
+                 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'))
+    assert cfg._cfg_dict == test_cfg._cfg_dict
+
+    # Test pretrained
+    cfg = get_config(
+        'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', pretrained=True)
+    assert cfg.model_path == 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'  # noqa E301
+
+    # Test load mmpose
+    get_config(
+        'mmpose::face/2d_kpt_sview_rgb_img/deeppose/wflw/res50_wflw_256x256'
+        '.py')
+
+
+@pytest.mark.skipif(
+    not is_installed('mmdet'), reason='mmdet and mmpose should be installed')
+def test_get_model():
+    # TODO compatible with downstream codebase.
+    DefaultScope.get_instance('test_get_model', scope_name='test_scope')
+    get_model('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
+    assert DefaultScope.get_current_instance().scope_name == 'test_scope'
+    DefaultScope._instance_dict.pop('test_get_model')
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 74e51e22818693c445c5ee00ac657ec8b8863e7a..d11fac13f69e63f97f08e973e66131568fce736b 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -193,7 +193,7 @@ class TestRegistry:
 
         return registries
 
-    def test_get_root_registry(self):
+    def test__get_root_registry(self):
         #        Hierarchical Registry
         #                           DOGS
         #                      _______|_______
@@ -304,7 +304,7 @@ class TestRegistry:
         assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None
         assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None
 
-    def test_search_child(self):
+    def test__search_child(self):
         #        Hierarchical Registry
         #                           DOGS
         #                      _______|_______
diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py
index cc8fa2b844d4ac095daa1997a888d6ed08401d1c..a487501fc6cfd0f4b937c4ddc25601ab0311b1d1 100644
--- a/tests/test_visualizer/test_visualizer.py
+++ b/tests/test_visualizer/test_visualizer.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import time
 from typing import Any
 from unittest import TestCase
 
@@ -117,12 +118,13 @@ class TestVisualizer(TestCase):
                 save_dir='temp_dir')
 
         # test global init
+        instance_name = 'visualizer' + str(time.time())
         visualizer = Visualizer.get_instance(
-            'visualizer',
+            instance_name,
             vis_backends=copy.deepcopy(self.vis_backend_cfg),
             save_dir='temp_dir')
         assert len(visualizer._vis_backends) == 2
-        visualizer_any = Visualizer.get_instance('visualizer')
+        visualizer_any = Visualizer.get_instance(instance_name)
         assert visualizer_any == visualizer
 
     def test_set_image(self):