From ed84dfd34db4994697bbb88f035c1577c70abb8f Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 26 Sep 2022 14:30:40 +0800 Subject: [PATCH] [Refactor] Refactor fileio without breaking back compatibility (#533) * [Refactor] Refactor fileio but without breaking bc * handle compatibility * fix format * modify io functions * fix ut * fix ut * rename method names * refine * refine docstring * fix ut in windows * update ut * minor fix * ensure client is not None when closing it * add more examples for list_dir_or_file interface * refine docstring * refine deprecated info * fix ut * add a description for lmdb docstring --- docs/en/api/fileio.rst | 37 +- docs/zh_cn/api/fileio.rst | 37 +- mmengine/fileio/__init__.py | 29 +- mmengine/fileio/backends/__init__.py | 14 + mmengine/fileio/backends/base.py | 36 + mmengine/fileio/backends/http_backend.py | 78 ++ mmengine/fileio/backends/lmdb_backend.py | 82 ++ mmengine/fileio/backends/local_backend.py | 543 +++++++++++ mmengine/fileio/backends/memcached_backend.py | 58 ++ mmengine/fileio/backends/petrel_backend.py | 768 +++++++++++++++ mmengine/fileio/backends/registry_utils.py | 117 +++ mmengine/fileio/file_client.py | 714 +------------- mmengine/fileio/handlers/__init__.py | 6 +- mmengine/fileio/handlers/registry_utils.py | 42 + mmengine/fileio/io.py | 883 +++++++++++++++++- mmengine/fileio/parse.py | 61 +- mmengine/hooks/checkpoint_hook.py | 57 +- mmengine/hooks/logger_hook.py | 37 +- mmengine/runner/checkpoint.py | 69 +- mmengine/runner/runner.py | 44 +- requirements/tests.txt | 1 + .../test_backends/test_backend_utils.py | 114 +++ .../test_base_storage_backend.py | 27 + .../test_backends/test_http_backend.py | 51 + .../test_backends/test_lmdb_backend.py | 35 + .../test_backends/test_local_backend.py | 486 ++++++++++ .../test_backends/test_memcached_backend.py | 59 ++ .../test_backends/test_petrel_backend.py | 858 +++++++++++++++++ tests/test_fileio/test_fileclient.py | 10 +- tests/test_fileio/test_fileio.py | 25 +- tests/test_fileio/test_io.py | 532 +++++++++++ tests/test_hooks/test_checkpoint_hook.py | 72 +- tests/test_hooks/test_logger_hook.py | 16 + 33 files changed, 5139 insertions(+), 859 deletions(-) create mode 100644 mmengine/fileio/backends/__init__.py create mode 100644 mmengine/fileio/backends/base.py create mode 100644 mmengine/fileio/backends/http_backend.py create mode 100644 mmengine/fileio/backends/lmdb_backend.py create mode 100644 mmengine/fileio/backends/local_backend.py create mode 100644 mmengine/fileio/backends/memcached_backend.py create mode 100644 mmengine/fileio/backends/petrel_backend.py create mode 100644 mmengine/fileio/backends/registry_utils.py create mode 100644 mmengine/fileio/handlers/registry_utils.py create mode 100644 tests/test_fileio/test_backends/test_backend_utils.py create mode 100644 tests/test_fileio/test_backends/test_base_storage_backend.py create mode 100644 tests/test_fileio/test_backends/test_http_backend.py create mode 100644 tests/test_fileio/test_backends/test_lmdb_backend.py create mode 100644 tests/test_fileio/test_backends/test_local_backend.py create mode 100644 tests/test_fileio/test_backends/test_memcached_backend.py create mode 100644 tests/test_fileio/test_backends/test_petrel_backend.py create mode 100644 tests/test_fileio/test_io.py diff --git a/docs/en/api/fileio.rst b/docs/en/api/fileio.rst index bf27fc0e..1b8c14b4 100644 --- a/docs/en/api/fileio.rst +++ b/docs/en/api/fileio.rst @@ -11,7 +11,7 @@ mmengine.fileio .. 
currentmodule:: mmengine.fileio -File Client +File Backend ---------------- .. autosummary:: @@ -22,11 +22,18 @@ File Client BaseStorageBackend FileClient HardDiskBackend + LocalBackend HTTPBackend LmdbBackend MemcachedBackend PetrelBackend +.. autosummary:: + :toctree: generated + :nosignatures: + + register_backend + File Handler ---------------- @@ -40,6 +47,12 @@ File Handler PickleHandler YamlHandler +.. autosummary:: + :toctree: generated + :nosignatures: + + register_handler + File IO ---------------- @@ -49,7 +62,27 @@ File IO dump load - register_handler + copy_if_symlink_fails + copyfile + copyfile_from_local + copyfile_to_local + copytree + copytree_from_local + copytree_to_local + exists + generate_presigned_url + get + get_file_backend + get_local_path + get_text + isdir + isfile + join_path + list_dir_or_file + put + put_text + remove + rmtree Parse File ---------------- diff --git a/docs/zh_cn/api/fileio.rst b/docs/zh_cn/api/fileio.rst index bf27fc0e..1b8c14b4 100644 --- a/docs/zh_cn/api/fileio.rst +++ b/docs/zh_cn/api/fileio.rst @@ -11,7 +11,7 @@ mmengine.fileio .. currentmodule:: mmengine.fileio -File Client +File Backend ---------------- .. autosummary:: @@ -22,11 +22,18 @@ File Client BaseStorageBackend FileClient HardDiskBackend + LocalBackend HTTPBackend LmdbBackend MemcachedBackend PetrelBackend +.. autosummary:: + :toctree: generated + :nosignatures: + + register_backend + File Handler ---------------- @@ -40,6 +47,12 @@ File Handler PickleHandler YamlHandler +.. autosummary:: + :toctree: generated + :nosignatures: + + register_handler + File IO ---------------- @@ -49,7 +62,27 @@ File IO dump load - register_handler + copy_if_symlink_fails + copyfile + copyfile_from_local + copyfile_to_local + copytree + copytree_from_local + copytree_to_local + exists + generate_presigned_url + get + get_file_backend + get_local_path + get_text + isdir + isfile + join_path + list_dir_or_file + put + put_text + remove + rmtree Parse File ---------------- diff --git a/mmengine/fileio/__init__.py b/mmengine/fileio/__init__.py index bea65832..81adcd4c 100644 --- a/mmengine/fileio/__init__.py +++ b/mmengine/fileio/__init__.py @@ -1,14 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. 
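Editor's note on the hunk below: the public surface of ``mmengine.fileio`` grows from a handful of names to the full functional API. A minimal usage sketch, assuming the io-level wrappers in ``mmengine/fileio/io.py`` (that hunk is not reproduced in this excerpt) mirror the backend methods shown later in this patch; the positional call to ``get_file_backend`` is likewise an assumption:

    >>> import tempfile, os.path as osp
    >>> from mmengine.fileio import exists, get, get_file_backend, list_dir_or_file, put_text
    >>> tmp_dir = tempfile.mkdtemp()
    >>> cfg_path = osp.join(tmp_dir, 'exp', 'config.yaml')
    >>> put_text('lr: 0.01\n', cfg_path)  # parent directory is created on demand
    >>> exists(cfg_path)
    True
    >>> get(cfg_path)
    b'lr: 0.01\n'
    >>> sorted(list_dir_or_file(tmp_dir, recursive=True))  # paths relative to tmp_dir, POSIX-style
    ['exp', 'exp/config.yaml']
    >>> type(get_file_backend(cfg_path)).__name__  # backend resolved from the URI prefix
    'LocalBackend'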
-from .file_client import (BaseStorageBackend, FileClient, HardDiskBackend, - HTTPBackend, LmdbBackend, MemcachedBackend, - PetrelBackend) -from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler -from .io import dump, load, register_handler +from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, + LocalBackend, MemcachedBackend, PetrelBackend, + register_backend) +from .file_client import FileClient, HardDiskBackend +from .handlers import (BaseFileHandler, JsonHandler, PickleHandler, + YamlHandler, register_handler) +from .io import (copy_if_symlink_fails, copyfile, copyfile_from_local, + copyfile_to_local, copytree, copytree_from_local, + copytree_to_local, dump, exists, generate_presigned_url, get, + get_file_backend, get_local_path, get_text, isdir, isfile, + join_path, list_dir_or_file, load, put, put_text, remove, + rmtree) from .parse import dict_from_file, list_from_file __all__ = [ 'BaseStorageBackend', 'FileClient', 'PetrelBackend', 'MemcachedBackend', - 'LmdbBackend', 'HardDiskBackend', 'HTTPBackend', 'load', 'dump', - 'register_handler', 'BaseFileHandler', 'JsonHandler', 'PickleHandler', - 'YamlHandler', 'list_from_file', 'dict_from_file' + 'LmdbBackend', 'HardDiskBackend', 'LocalBackend', 'HTTPBackend', + 'copy_if_symlink_fails', 'copyfile', 'copyfile_from_local', + 'copyfile_to_local', 'copytree', 'copytree_from_local', + 'copytree_to_local', 'exists', 'generate_presigned_url', 'get', + 'get_file_backend', 'get_local_path', 'get_text', 'isdir', 'isfile', + 'join_path', 'list_dir_or_file', 'put', 'put_text', 'remove', 'rmtree', + 'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler', + 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file', + 'register_backend' ] diff --git a/mmengine/fileio/backends/__init__.py b/mmengine/fileio/backends/__init__.py new file mode 100644 index 00000000..fa000897 --- /dev/null +++ b/mmengine/fileio/backends/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseStorageBackend +from .http_backend import HTTPBackend +from .lmdb_backend import LmdbBackend +from .local_backend import LocalBackend +from .memcached_backend import MemcachedBackend +from .petrel_backend import PetrelBackend +from .registry_utils import backends, prefix_to_backends, register_backend + +__all__ = [ + 'BaseStorageBackend', 'LocalBackend', 'HTTPBackend', 'LmdbBackend', + 'MemcachedBackend', 'PetrelBackend', 'register_backend', 'backends', + 'prefix_to_backends' +] diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py new file mode 100644 index 00000000..4ebabd3b --- /dev/null +++ b/mmengine/fileio/backends/base.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: :meth:`get()` and + :meth:`get_text()`. + + - :meth:`get()` reads the file as a byte stream. + - :meth:`get_text()` reads the file as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + # This attribute will be deprecated in future. 
+ _allow_symlink = False + + @property + def allow_symlink(self): + warnings.warn('allow_symlink will be deprecated in future', + DeprecationWarning) + return self._allow_symlink + + @property + def name(self): + return self.__class__.__name__ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass diff --git a/mmengine/fileio/backends/http_backend.py b/mmengine/fileio/backends/http_backend.py new file mode 100644 index 00000000..b3e65bbd --- /dev/null +++ b/mmengine/fileio/backends/http_backend.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Union +from urllib.request import urlopen + +from .base import BaseStorageBackend + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage bachend.""" + + def get(self, filepath: str) -> bytes: + """Read bytes from a given ``filepath``. + + Args: + filepath (str): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get('http://path/of/file') + b'hello world' + """ + return urlopen(filepath).read() + + def get_text(self, filepath, encoding='utf-8') -> str: + """Read text from a given ``filepath``. + + Args: + filepath (str): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get_text('http://path/of/file') + 'hello world' + """ + return urlopen(filepath).read().decode(encoding) + + @contextmanager + def get_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Yields: + Iterable[str]: Only yield one temporary path. + + Examples: + >>> backend = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with backend.get_local_path('http://path/of/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py new file mode 100644 index 00000000..eb47923e --- /dev/null +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Union + +from .base import BaseStorageBackend + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_path (str): Lmdb database path. + readonly (bool): Lmdb environment parameter. If True, disallow any + write operations. Defaults to True. + lock (bool): Lmdb environment parameter. If False, when concurrent + access occurs, do not lock the database. Defaults to False. + readahead (bool): Lmdb environment parameter. If False, disable the OS + filesystem readahead mechanism, which may improve random read + performance when a database is larger than RAM. Defaults to False. 
+ **kwargs: Keyword arguments passed to `lmdb.open`. + + Attributes: + db_path (str): Lmdb database path. + """ + + def __init__(self, + db_path, + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb # noqa: F401 + except ImportError: + raise ImportError( + 'Please run "pip install lmdb" to enable LmdbBackend.') + + self.db_path = str(db_path) + self.readonly = readonly + self.lock = lock + self.readahead = readahead + self.kwargs = kwargs + self._client = None + + def get(self, filepath: Union[str, Path]) -> bytes: + """Get values according to the filepath. + + Args: + filepath (str or Path): Here, filepath is the lmdb key. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LmdbBackend('path/to/lmdb') + >>> backend.get('key') + b'hello world' + """ + if self._client is None: + self._client = self._get_client() + + filepath = str(filepath) + with self._client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + def _get_client(self): + import lmdb + + return lmdb.open( + self.db_path, + readonly=self.readonly, + lock=self.lock, + readahead=self.readahead, + **self.kwargs) + + def __del__(self): + if self._client is not None: + self._client.close() diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py new file mode 100644 index 00000000..0c6c7774 --- /dev/null +++ b/mmengine/fileio/backends/local_backend.py @@ -0,0 +1,543 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union + +import mmengine +from .base import BaseStorageBackend + + +class LocalBackend(BaseStorageBackend): + """Raw local storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get(filepath) + b'hello world' + """ + with open(filepath, 'rb') as f: + value = f.read() + return value + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + with open(filepath, encoding=encoding) as f: + text = f.read() + return text + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. 
+ + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put(b'hello world', filepath) + """ + mmengine.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_text('hello world', filepath) + """ + mmengine.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.exists(filepath) + True + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/dir' + >>> backend.isdir(filepath) + True + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.isfile(filepath) + True + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + + Examples: + >>> backend = LocalBackend() + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> backend.join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + # TODO, if filepath or filepaths are Path, should return Path + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing. + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> backend = LocalBackend() + >>> with backend.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + """ + yield filepath + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. 
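The read/write half of ``LocalBackend`` above (``get``, ``get_text``, ``put``, ``put_text`` plus the path predicates and ``get_local_path``) composes into a simple round trip. A short sketch, using a temporary directory so it is safe to run anywhere:

    >>> import tempfile
    >>> from mmengine.fileio import LocalBackend
    >>> backend = LocalBackend()
    >>> tmp_dir = tempfile.mkdtemp()
    >>> filepath = backend.join_path(tmp_dir, 'subdir', 'note.txt')
    >>> backend.put_text('hello world', filepath)  # creates 'subdir' on demand
    >>> backend.isfile(filepath)
    True
    >>> backend.get_text(filepath)
    'hello world'
    >>> backend.get(filepath)
    b'hello world'
    >>> with backend.get_local_path(filepath) as path:
    ...     path == filepath  # LocalBackend just yields the original path
    True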
+ + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile(src, dst) + '/path1/of/dir/file' + """ + return shutil.copy(src, dst) + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + TODO: Whether to support dirs_exist_ok parameter. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree(src, dst) + '/path/of/dir2' + """ + return shutil.copytree(src, dst) + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a local file src to dst and return the destination file. Same + as :meth:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_from_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_from_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. Same as + :meth:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy the file src to local dst and return the destination file. Same + as :meth:`copyfile`. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. 
If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_to_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.remove(filepath) + """ + if not self.exists(filepath): + raise FileNotFoundError(f'filepath {filepath} does not exist') + + if self.isdir(filepath): + raise IsADirectoryError('filepath should be a file') + + os.remove(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> dir_path = '/path/of/dir' + >>> backend.rmtree(dir_path) + """ + shutil.rmtree(dir_path) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src + to dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + + Returns: + bool: Return True if successfully create a symbolic link pointing + to src. Otherwise, return False. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> backend.copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> backend.copy_if_symlink_fails(src, dst) + True + """ + try: + os.symlink(src, dst) + return True + except Exception: + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. 
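The copy/removal helpers above and the directory listing that follows cover the rest of the local lifecycle; a combined sketch (the symlink step returns ``False`` and silently falls back to a real copy on platforms where symlinks are unavailable):

    >>> import tempfile
    >>> from mmengine.fileio import LocalBackend
    >>> backend = LocalBackend()
    >>> src_dir, dst_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
    >>> backend.put(b'payload', backend.join_path(src_dir, 'a.bin'))
    >>> backend.put_text('text', backend.join_path(src_dir, 'b.txt'))
    >>> copied = backend.copytree(src_dir, backend.join_path(dst_dir, 'copy'))
    >>> sorted(backend.list_dir_or_file(copied, list_dir=False, suffix='.txt'))
    ['b.txt']
    >>> backend.copy_if_symlink_fails(copied, backend.join_path(dst_dir, 'link'))
    True
    >>> backend.rmtree(dst_dir)
    >>> backend.rmtree(src_dir)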
+ + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = LocalBackend() + >>> dir_path = '/path/of/dir' + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ # noqa: E501 + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) diff --git a/mmengine/fileio/backends/memcached_backend.py b/mmengine/fileio/backends/memcached_backend.py new file mode 100644 index 00000000..2458e7c6 --- /dev/null +++ b/mmengine/fileio/backends/memcached_backend.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from typing import Union + +from .base import BaseStorageBackend + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str, optional): Additional path to be appended to `sys.path`. + Defaults to None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath: Union[str, Path]): + """Get values according to the filepath. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. 
+ + Examples: + >>> server_list_cfg = '/path/of/server_list.conf' + >>> client_cfg = '/path/of/mc.conf' + >>> backend = MemcachedBackend(server_list_cfg, client_cfg) + >>> backend.get('/path/of/file') + b'hello world' + """ + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py new file mode 100644 index 00000000..bfb23bd5 --- /dev/null +++ b/mmengine/fileio/backends/petrel_backend.py @@ -0,0 +1,768 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import re +import tempfile +from contextlib import contextmanager +from pathlib import Path +from shutil import SameFileError +from typing import Generator, Iterator, Optional, Tuple, Union + +import mmengine +from mmengine.utils import has_method +from .base import BaseStorageBackend + + +class PetrelBackend(BaseStorageBackend): + """Petrel storage backend (for internal usage). + + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + + Args: + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Defaults to None. + enable_mc (bool, optional): Whether to enable memcached support. + Defaults to True. + + Examples: + >>> backend = PetrelBackend() + >>> filepath1 = 'petrel://path/of/file' + >>> filepath2 = 'cluster-name:petrel://path/of/file' + >>> backend.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster + """ + + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): + try: + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + + self._client = client.Client(enable_mc=enable_mc) + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str or Path): Path to be mapped. + """ + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v, 1) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def _replace_prefix(self, filepath: Union[str, Path]) -> str: + filepath = str(filepath) + return filepath.replace('petrel://', 's3://') + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. 
+ + Returns: + bytes: Return bytes read from filepath. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.get(filepath) + b'hello world' + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + value = self._client.Get(filepath) + return value + + def get_text( + self, + filepath: Union[str, Path], + encoding: str = 'utf-8', + ) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.put(b'hello world', filepath) + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + self._client.put(filepath, obj) + + def put_text( + self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8', + ) -> None: + """Write text to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Defaults to 'utf-8'. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.put_text('hello world', filepath) + """ + self.put(bytes(obj, encoding=encoding), filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.exists(filepath) + True + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
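Since exercising ``PetrelBackend`` itself requires the ``petrel_client`` SDK, the easiest way to reason about what the class does to a path is to replay the three private helpers above (``_map_path``, ``_format_path``, ``_replace_prefix``) by hand. A standalone sketch of the same string handling, runnable without the SDK:

    >>> import re
    >>> def normalize(filepath, path_mapping=None):
    ...     if path_mapping is not None:  # _map_path: first match only
    ...         for src, dst in path_mapping.items():
    ...             filepath = filepath.replace(src, dst, 1)
    ...     filepath = re.sub(r'\\+', '/', filepath)  # _format_path
    ...     return filepath.replace('petrel://', 's3://')  # _replace_prefix
    >>> normalize('petrel://bucket\\images\\0001.jpg')
    's3://bucket/images/0001.jpg'
    >>> normalize('./data/0001.jpg', path_mapping={'./data': 'petrel://bucket/data'})
    's3://bucket/data/0001.jpg'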
+ + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/dir' + >>> backend.isdir(filepath) + True + """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.isfile(filepath) + True + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + return self._client.contains(filepath) + + def join_path( + self, + filepath: Union[str, Path], + *filepaths: Union[str, Path], + ) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.join_path(filepath, 'another/path') + 'petrel://path/of/file/another/path' + >>> backend.join_path(filepath, '/another/path') + 'petrel://path/of/file/another/path' + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_path = self._format_path(self._map_path(path)) + formatted_paths.append(formatted_path.lstrip('/')) + + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str or Path): Download a file from ``filepath``. + + Yields: + Iterable[str]: Only yield one temporary path. + + Examples: + >>> backend = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> filepath = 'petrel://path/of/file' + >>> with backend.get_local_path(filepath) as path: + ... # do something here + """ + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. 
If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'petrel://path/of/file' + >>> dst = 'petrel://path/of/file1' + >>> backend.copyfile(src, dst) + 'petrel://path/of/file1' + + >>> # dst is a directory + >>> dst = 'petrel://path/of/dir' + >>> backend.copyfile(src, dst) + 'petrel://path/of/dir/file' + """ + src = self._format_path(self._map_path(src)) + dst = self._format_path(self._map_path(dst)) + if self.isdir(dst): + dst = self.join_path(dst, src.split('/')[-1]) + + if src == dst: + raise SameFileError('src and dst should not be same') + + self.put(self.get(src), dst) + return dst + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/dir' + >>> dst = 'petrel://path/of/dir1' + >>> backend.copytree(src, dst) + 'petrel://path/of/dir1' + """ + src = self._format_path(self._map_path(src)) + dst = self._format_path(self._map_path(dst)) + + if self.exists(dst): + raise FileExistsError('dst should not exist') + + for path in self.list_dir_or_file(src, list_dir=False, recursive=True): + src_path = self.join_path(src, path) + dst_path = self.join_path(dst, path) + self.put(self.get(src_path), dst_path) + + return dst + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Upload a local file src to dst and return the destination file. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'path/of/your/file' + >>> dst = 'petrel://path/of/file1' + >>> backend.copyfile_from_local(src, dst) + 'petrel://path/of/file1' + + >>> # dst is a directory + >>> dst = 'petrel://path/of/dir' + >>> backend.copyfile_from_local(src, dst) + 'petrel://path/of/dir/file' + """ + dst = self._format_path(self._map_path(dst)) + if self.isdir(dst): + dst = self.join_path(dst, osp.basename(src)) + + with open(src, 'rb') as f: + self.put(f.read(), dst) + + return dst + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. 
+ + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'path/of/your/dir' + >>> dst = 'petrel://path/of/dir1' + >>> backend.copytree_from_local(src, dst) + 'petrel://path/of/dir1' + """ + dst = self._format_path(self._map_path(dst)) + if self.exists(dst): + raise FileExistsError('dst should not exist') + + src = str(src) + + for cur_dir, _, files in os.walk(src): + for f in files: + src_path = osp.join(cur_dir, f) + dst_path = self.join_path(dst, src_path.replace(src, '')) + self.copyfile_from_local(src_path, dst_path) + + return dst + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = PetrelBackend() + >>> # dst is a file + >>> src = 'petrel://path/of/file' + >>> dst = 'path/of/your/file' + >>> backend.copyfile_to_local(src, dst) + 'path/of/your/file' + + >>> # dst is a directory + >>> dst = 'path/of/your/dir' + >>> backend.copyfile_to_local(src, dst) + 'path/of/your/dir/file' + """ + if osp.isdir(dst): + basename = osp.basename(src) + if isinstance(dst, str): + dst = osp.join(dst, basename) + else: + assert isinstance(dst, Path) + dst = dst / basename + + with open(dst, 'wb') as f: + f.write(self.get(src)) + + return dst + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/dir' + >>> dst = 'path/of/your/dir' + >>> backend.copytree_to_local(src, dst) + 'path/of/your/dir' + """ + for path in self.list_dir_or_file(src, list_dir=False, recursive=True): + dst_path = osp.join(dst, path) + mmengine.mkdir_or_exist(osp.dirname(dst_path)) + with open(dst_path, 'wb') as f: + f.write(self.get(self.join_path(src, path))) + + return dst + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. 
+ + Examples: + >>> backend = PetrelBackend() + >>> filepath = 'petrel://path/of/file' + >>> backend.remove(filepath) + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev ' + 'branch instead.') + + if not self.exists(filepath): + raise FileNotFoundError(f'filepath {filepath} does not exist') + + if self.isdir(filepath): + raise IsADirectoryError('filepath should be a file') + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + filepath = self._replace_prefix(filepath) + self._client.delete(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> backend = PetrelBackend() + >>> dir_path = 'petrel://path/of/dir' + >>> backend.rmtree(dir_path) + """ + for path in self.list_dir_or_file( + dir_path, list_dir=False, recursive=True): + filepath = self.join_path(dir_path, path) + self.remove(filepath) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + Directly copy src to dst because PetrelBacekend does not support create + a symbolic link. + + Args: + src (str or Path): A file or directory to be copied. + dst (str or Path): Copy a file or directory to dst. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + + Returns: + bool: Return False because PetrelBackend does not support create + a symbolic link. + + Examples: + >>> backend = PetrelBackend() + >>> src = 'petrel://path/of/file' + >>> dst = 'petrel://path/of/your/file' + >>> backend.copy_if_symlink_fails(src, dst) + False + >>> src = 'petrel://path/of/dir' + >>> dst = 'petrel://path/of/your/dir' + >>> backend.copy_if_symlink_fails(src, dst) + False + """ + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = PetrelBackend() + >>> dir_path = 'petrel://path/of/dir' + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... 
print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ # noqa: E501 + if not has_method(self._client, 'list'): + raise NotImplementedError( + 'Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.') + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + dir_path = self._replace_prefix(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.join_path(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.join_path(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + def generate_presigned_url(self, + url: str, + client_method: str = 'get_object', + expires_in: int = 3600) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on Petrel backend. + + Note: + Now only work on Petrel backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Default: 'get_object'. + expires_in (int): expires, in seconds. Default: 3600. + + Returns: + str: Generated presigned url. + """ + return self._client.generate_presigned_url(url, client_method, + expires_in) diff --git a/mmengine/fileio/backends/registry_utils.py b/mmengine/fileio/backends/registry_utils.py new file mode 100644 index 00000000..4578a4ca --- /dev/null +++ b/mmengine/fileio/backends/registry_utils.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
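The registry module that follows keeps two module-level dicts (``backends`` and ``prefix_to_backends``) and lets ``register_backend`` be used either as a plain call or as a decorator. A sketch of how a downstream project could hook in its own backend; ``TarBackend`` and the ``tar`` prefix are made-up names for illustration only:

    >>> from mmengine.fileio import BaseStorageBackend, register_backend
    >>> @register_backend('tar', prefixes='tar')
    ... class TarBackend(BaseStorageBackend):
    ...     def get(self, filepath):
    ...         return b'...'  # stub; a real backend would read the archive member
    ...     def get_text(self, filepath, encoding='utf-8'):
    ...         return self.get(filepath).decode(encoding)
    >>> from mmengine.fileio.backends import backends, prefix_to_backends
    >>> backends['tar'] is TarBackend and prefix_to_backends['tar'] is TarBackend
    True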
+import inspect +from typing import Optional, Type, Union + +from .base import BaseStorageBackend +from .http_backend import HTTPBackend +from .lmdb_backend import LmdbBackend +from .local_backend import LocalBackend +from .memcached_backend import MemcachedBackend +from .petrel_backend import PetrelBackend + +backends: dict = {} +prefix_to_backends: dict = {} + + +def _register_backend(name: str, + backend: Type[BaseStorageBackend], + force: bool = False, + prefixes: Union[str, list, tuple, None] = None): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (BaseStorageBackend): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + """ + global backends, prefix_to_backends + + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') + + if not inspect.isclass(backend): + raise TypeError(f'backend should be a class, but got {type(backend)}') + if not issubclass(backend, BaseStorageBackend): + raise TypeError( + f'backend {backend} is not a subclass of BaseStorageBackend') + + if name in backends and not force: + raise ValueError(f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') + backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + for prefix in prefixes: + if prefix in prefix_to_backends and not force: + raise ValueError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + + prefix_to_backends[prefix] = backend + + +def register_backend(name: str, + backend: Optional[Type[BaseStorageBackend]] = None, + force: bool = False, + prefixes: Union[str, list, tuple, None] = None): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + + This method can be used as a normal method or a decorator. + + Examples: + + >>> class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + >>> register_backend('new', NewBackend) + + >>> @register_backend('new') + ... class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... 
return filepath + """ + if backend is not None: + _register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + _register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + +register_backend('local', LocalBackend, prefixes='') +register_backend('memcached', MemcachedBackend) +register_backend('lmdb', LmdbBackend) +# To avoid breaking backward Compatibility, 's3' is also used as a +# prefix for PetrelBackend +register_backend('petrel', PetrelBackend, prefixes=['petrel', 's3']) +register_backend('http', HTTPBackend, prefixes=['http', 'https']) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index a371a186..7f4a6716 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -1,709 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect -import os -import os.path as osp -import re -import tempfile -from abc import ABCMeta, abstractmethod +import warnings from contextlib import contextmanager from pathlib import Path from typing import Any, Generator, Iterator, Optional, Tuple, Union -from urllib.request import urlopen -from mmengine.utils import has_method, is_filepath, mkdir_or_exist +from mmengine.utils import is_filepath +from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, + LocalBackend, MemcachedBackend, PetrelBackend) -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: ``get()`` and ``get_text()``. - ``get()`` reads the file as a byte stream and ``get_text()`` reads the file - as texts. - """ +class HardDiskBackend(LocalBackend): + """Raw hard disks storage backend.""" - # a flag to indicate whether the backend can create a symlink for a file - _allow_symlink = False + def __init__(self) -> None: + warnings.warn( + '"HardDiskBackend" is the alias of "LocalBackend" ' + 'and the former will be deprecated in future.', DeprecationWarning) @property def name(self): return self.__class__.__name__ - @property - def allow_symlink(self): - return self._allow_symlink - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass - - -class PetrelBackend(BaseStorageBackend): - """Petrel storage backend (for internal usage). - - PetrelBackend supports reading and writing data to multiple clusters. - If the file path contains the cluster name, PetrelBackend will read data - from specified cluster or write data to it. Otherwise, PetrelBackend will - access the default cluster. - - Args: - path_mapping (dict, optional): Path mapping dict from local path to - Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in - ``filepath`` will be replaced by ``dst``. Default: None. - enable_mc (bool, optional): Whether to enable memcached support. - Default: True. 
- - Examples: - >>> filepath1 = 's3://path/of/file' - >>> filepath2 = 'cluster-name:s3://path/of/file' - >>> client = PetrelBackend() - >>> client.get(filepath1) # get data from default cluster - >>> client.get(filepath2) # get data from 'cluster-name' cluster - """ - - def __init__(self, - path_mapping: Optional[dict] = None, - enable_mc: bool = True): - try: - from petrel_client import client - except ImportError: - raise ImportError('Please install petrel_client to enable ' - 'PetrelBackend.') - - self._client = client.Client(enable_mc=enable_mc) - assert isinstance(path_mapping, dict) or path_mapping is None - self.path_mapping = path_mapping - - def _map_path(self, filepath: Union[str, Path]) -> str: - """Map ``filepath`` to a string path whose prefix will be replaced by - :attr:`self.path_mapping`. - - Args: - filepath (str): Path to be mapped. - """ - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v, 1) - return filepath - - def _format_path(self, filepath: str) -> str: - """Convert a ``filepath`` to standard format of petrel oss. - - If the ``filepath`` is concatenated by ``os.path.join``, in a Windows - environment, the ``filepath`` will be the format of - 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the - above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. - - Args: - filepath (str): Path to be formatted. - """ - return re.sub(r'\\+', '/', filepath) - - def get(self, filepath: Union[str, Path]) -> memoryview: - """Read data from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - memoryview: A memory view of expected bytes object to avoid - copying. The memoryview object can be converted to bytes by - ``value_buf.tobytes()``. - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - value = self._client.Get(filepath) - value_buf = memoryview(value) - return value_buf - - def get_text(self, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return str(self.get(filepath), encoding=encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Save data to a given ``filepath``. - - Args: - obj (bytes): Data to be saved. - filepath (str or Path): Path to write data. - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - self._client.put(filepath, obj) - - def put_text(self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> None: - """Save data to a given ``filepath``. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to encode the ``obj``. - Default: 'utf-8'. - """ - self.put(bytes(obj, encoding=encoding), filepath) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. 
- """ - if not has_method(self._client, 'delete'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `delete` method, please use a higher version or dev' - ' branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - self._client.delete(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - if not (has_method(self._client, 'contains') - and has_method(self._client, 'isdir')): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` and `isdir` methods, please use a higher' - 'version or dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.contains(filepath) or self._client.isdir(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - if not has_method(self._client, 'isdir'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `isdir` method, please use a higher version or dev' - ' branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - if not has_method(self._client, 'contains'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `contains` method, please use a higher version or ' - 'dev branch instead.') - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - return self._client.contains(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - """Concatenate all file paths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result after concatenation. - """ - filepath = self._format_path(self._map_path(filepath)) - if filepath.endswith('/'): - filepath = filepath[:-1] - formatted_paths = [filepath] - for path in filepaths: - formatted_paths.append(self._format_path(self._map_path(path))) - return '/'.join(formatted_paths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, - Path]) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` and return a temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str | Path): Download a file from ``filepath``. - - Examples: - >>> client = PetrelBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with client.get_local_path('s3://path/of/your/file') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one temporary path. 
- """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - assert self.isfile(filepath) - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - Petrel has no concept of directories but it simulates the directory - hierarchy in the filesystem through public prefixes. In addition, - if the returned path ends with '/', it means the path is a public - prefix which is a logical directory. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - In addition, the returned path of directory will not contains the - suffix '/' which is consistent with other backends. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Default: True. - list_file (bool): List the path of files. Default: True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Default: None. - recursive (bool): If set to True, recursively scan the - directory. Default: False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - """ - if not has_method(self._client, 'list'): - raise NotImplementedError( - 'Current version of Petrel Python SDK has not supported ' - 'the `list` method, please use a higher version or dev' - ' branch instead.') - - dir_path = self._map_path(dir_path) - dir_path = self._format_path(dir_path) - if list_dir and suffix is not None: - raise TypeError( - '`list_dir` should be False when `suffix` is not None') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - # Petrel's simulated directory hierarchy assumes that directory paths - # should end with `/` - if not dir_path.endswith('/'): - dir_path += '/' - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for path in self._client.list(dir_path): - # the `self.isdir` is not used here to determine whether path - # is a directory, because `self.isdir` relies on - # `self._client.list` - if path.endswith('/'): # a directory path - next_dir_path = self.join_path(dir_path, path) - if list_dir: - # get the relative path and exclude the last - # character '/' - rel_dir = next_dir_path[len(root):-1] - yield rel_dir - if recursive: - yield from _list_dir_or_file(next_dir_path, list_dir, - list_file, suffix, - recursive) - else: # a file path - absolute_path = self.join_path(dir_path, path) - rel_path = absolute_path[len(root):] - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - -class MemcachedBackend(BaseStorageBackend): - """Memcached storage backend. - - Attributes: - server_list_cfg (str): Config file for memcached server list. - client_cfg (str): Config file for memcached client. - sys_path (str | None): Additional path to be appended to `sys.path`. - Default: None. 
- """ - - def __init__(self, server_list_cfg, client_cfg, sys_path=None): - if sys_path is not None: - import sys - sys.path.append(sys_path) - try: - import mc - except ImportError: - raise ImportError( - 'Please install memcached to enable MemcachedBackend.') - - self.server_list_cfg = server_list_cfg - self.client_cfg = client_cfg - self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, - self.client_cfg) - # mc.pyvector servers as a point which points to a memory cache - self._mc_buffer = mc.pyvector() - - def get(self, filepath): - filepath = str(filepath) - import mc - self._client.Get(filepath, self._mc_buffer) - value_buf = mc.ConvertBuffer(self._mc_buffer) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_path (str): Lmdb database path. - readonly (bool, optional): Lmdb environment parameter. If True, - disallow any write operations. Default: True. - lock (bool, optional): Lmdb environment parameter. If False, when - concurrent access occurs, do not lock the database. Default: False. - readahead (bool, optional): Lmdb environment parameter. If False, - disable the OS filesystem readahead mechanism, which may improve - random read performance when a database is larger than RAM. - Default: False. - - Attributes: - db_path (str): Lmdb database path. - """ - - def __init__(self, - db_path, - readonly=True, - lock=False, - readahead=False, - **kwargs): - try: - import lmdb # NOQA - except ImportError: - raise ImportError('Please install lmdb to enable LmdbBackend.') - - self.db_path = str(db_path) - self.readonly = readonly - self.lock = lock - self.readahead = readahead - self.kwargs = kwargs - self._client = None - - def get(self, filepath): - """Get values according to the filepath. - - Args: - filepath (str | obj:`Path`): Here, filepath is the lmdb key. - """ - if self._client is None: - self._client = self._get_client() - - with self._client.begin(write=False) as txn: - value_buf = txn.get(str(filepath).encode('utf-8')) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - def _get_client(self): - import lmdb - - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs) - - def __del__(self): - self._client.close() - - -class HardDiskBackend(BaseStorageBackend): - """Raw hard disks storage backend.""" - - _allow_symlink = True - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read data from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - """ - with open(filepath, 'rb') as f: - value_buf = f.read() - return value_buf - - def get_text(self, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - with open(filepath, encoding=encoding) as f: - value_buf = f.read() - return value_buf - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. 
- filepath (str or Path): Path to write data. - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'wb') as f: - f.write(obj) - - def put_text(self, - obj: str, - filepath: Union[str, Path], - encoding: str = 'utf-8') -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Default: 'utf-8'. - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, 'w', encoding=encoding) as f: - f.write(obj) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - """ - os.remove(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return osp.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return osp.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return osp.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: - """Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of *filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - """ - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, - Path]) -> Generator[Union[str, Path], None, None]: - """Only for unified API and do nothing.""" - yield filepath - - def list_dir_or_file(self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Default: True. - list_file (bool): List the path of files. Default: True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Default: None. - recursive (bool): If set to True, recursively scan the - directory. Default: False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. 
- """ - if list_dir and suffix is not None: - raise TypeError('`suffix` should be None when `list_dir` is True') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('`suffix` must be a string or tuple of strings') - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None - or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, - list_file, suffix, - recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, - recursive) - - -class HTTPBackend(BaseStorageBackend): - """HTTP and HTTPS storage bachend.""" - - def get(self, filepath): - value_buf = urlopen(filepath).read() - return value_buf - - def get_text(self, filepath, encoding='utf-8'): - value_buf = urlopen(filepath).read() - return value_buf.decode(encoding) - - @contextmanager - def get_local_path( - self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath``. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str): Download a file from ``filepath``. - - Examples: - >>> client = HTTPBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with client.get_local_path('http://path/of/your/file') as path: - ... # do something here - """ - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - class FileClient: """A general file client to access files in different backends. @@ -719,9 +37,13 @@ class FileClient: avoid repeated object creation. If the arguments are the same, the same object will be returned. + Warning: + `FileClient` will be deprecated in future. Please use io functions + in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io + Args: backend (str, optional): The storage backend type. Options are "disk", - "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + "memcached", "lmdb", "http" and "petrel". Default: None. prefix (str, optional): The prefix of the registered storage backend. Options are "s3", "http", "https". Default: None. @@ -751,6 +73,7 @@ class FileClient: _prefix_to_backends: dict = { 's3': PetrelBackend, + 'petrel': PetrelBackend, 'http': HTTPBackend, 'https': HTTPBackend, } @@ -760,6 +83,11 @@ class FileClient: client: Any def __new__(cls, backend=None, prefix=None, **kwargs): + warnings.warn( + '"FileClient" will be deprecated in future. 
Please use io ' + 'functions in ' + 'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', # noqa: E501 + DeprecationWarning) if backend is None and prefix is None: backend = 'disk' if backend is not None and backend not in cls._backends: diff --git a/mmengine/fileio/handlers/__init__.py b/mmengine/fileio/handlers/__init__.py index aa24d919..391a60c3 100644 --- a/mmengine/fileio/handlers/__init__.py +++ b/mmengine/fileio/handlers/__init__.py @@ -2,6 +2,10 @@ from .base import BaseFileHandler from .json_handler import JsonHandler from .pickle_handler import PickleHandler +from .registry_utils import file_handlers, register_handler from .yaml_handler import YamlHandler -__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler'] +__all__ = [ + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'register_handler', 'file_handlers' +] diff --git a/mmengine/fileio/handlers/registry_utils.py b/mmengine/fileio/handlers/registry_utils.py new file mode 100644 index 00000000..106fc881 --- /dev/null +++ b/mmengine/fileio/handlers/registry_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import is_list_of +from .base import BaseFileHandler +from .json_handler import JsonHandler +from .pickle_handler import PickleHandler +from .yaml_handler import YamlHandler + +file_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), + 'pickle': PickleHandler(), + 'pkl': PickleHandler(), +} + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError( + f'handler must be a child of BaseFileHandler, not {type(handler)}') + if isinstance(file_formats, str): + file_formats = [file_formats] + if not is_list_of(file_formats, str): + raise TypeError('file_formats must be a str or a list of str') + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index 7374aa15..62f9a4ef 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -1,21 +1,797 @@ # Copyright (c) OpenMMLab. All rights reserved. +"""This module provides unified file I/O related functions, which support +operating I/O with different file backends based on the specified filepath or +backend_args. + +MMEngine currently supports five file backends: + +- LocalBackend +- PetrelBackend +- HTTPBackend +- LmdbBackend +- MemcacheBackend + +Note that this module provide a union of all of the above file backends so +NotImplementedError will be raised if the interface in the file backend is not +implemented. + +There are two ways to call a method of a file backend: + +- Initialize a file backend with ``get_file_backend`` and call its methods. +- Directory call unified I/O functions, which will call ``get_file_backend`` + first and then call the corresponding backend method. 
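A minimal sketch of the two calling styles described above, assuming nothing beyond a writable local path (the file name below is hypothetical):

import mmengine.fileio as fileio

# Write a small text file so both styles have something to read.
fileio.put_text('hello world', '/tmp/mmengine_fileio_demo.txt')

# Style 1: obtain a backend explicitly and call its methods.
backend = fileio.get_file_backend('/tmp/mmengine_fileio_demo.txt')
print(backend.get_text('/tmp/mmengine_fileio_demo.txt'))  # 'hello world'

# Style 2: call the unified I/O functions directly; get_file_backend is
# invoked internally based on the path prefix ('' maps to LocalBackend).
print(fileio.get_text('/tmp/mmengine_fileio_demo.txt'))  # 'hello world'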
+ +Examples: + >>> # Initialize a file backend and call its methods + >>> import mmengine.fileio as fileio + >>> backend = fileio.get_file_backend(backend_args={'backend': 'petrel'}) + >>> backend.get('s3://path/of/your/file') + + >>> # Directory call unified I/O functions + >>> fileio.get('s3://path/of/your/file') +""" +import json +import warnings +from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union -from mmengine.utils import is_list_of, is_str +from mmengine.utils import is_filepath, is_str +from .backends import backends, prefix_to_backends from .file_client import FileClient -from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +# file_handlers and register_handler had been moved to +# mmengine/fileio/handlers/registry_utis. Import them +# in this file to keep backward compatibility. +from .handlers import file_handlers, register_handler # noqa: F401 + +backend_instances: dict = {} + + +def _parse_uri_prefix(uri: Union[str, Path]) -> str: + """Parse the prefix of uri. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> _parse_uri_prefix('/home/path/of/your/file') + '' + >>> _parse_uri_prefix('s3://path/of/your/file') + 's3' + >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') + 's3' + + Returns: + str: Return the prefix of uri if the uri contains '://'. Otherwise, + return ''. + """ + assert is_filepath(uri) + uri = str(uri) + # if uri does not contains '://', the uri will be handled by + # LocalBackend by default + if '://' not in uri: + return '' + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contain the cluster + # name like clusterName:s3://path/of/your/file + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix + + +def _get_file_backend(prefix: str, backend_args: dict): + """Return a file backend based on the prefix or backend_args. + + Args: + prefix (str): Prefix of uri. + backend_args (dict): Arguments to instantiate the corresponding + backend. + """ + # backend name has a higher priority + if 'backend' in backend_args: + backend_name = backend_args.pop('backend') + backend = backends[backend_name](**backend_args) + else: + backend = prefix_to_backends[prefix](**backend_args) + return backend + + +def get_file_backend( + uri: Union[str, Path, None] = None, + *, + backend_args: Optional[dict] = None, + enable_singleton: bool = False, +): + """Return a file backend based on the prefix of uri or backend_args. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + enable_singleton (bool): Whether to enable the singleton pattern. + If it is True, the backend created will be reused if the + signature is same with the previous one. Defaults to False. + + Returns: + BaseStorageBackend: Instantiated Backend object. 
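Because ``get_file_backend`` caches instances keyed on the uri prefix plus the serialized ``backend_args`` only when ``enable_singleton`` is set, reuse is easy to show in a short sketch (paths are hypothetical):

from mmengine.fileio import get_file_backend

# Same prefix ('' for plain local paths) and same backend_args, so the
# cached LocalBackend instance is returned on the second call.
b1 = get_file_backend('/path/of/file', enable_singleton=True)
b2 = get_file_backend('/path/of/another/file', enable_singleton=True)
assert b1 is b2

# Without enable_singleton, a fresh backend is constructed on every call.
b3 = get_file_backend('/path/of/file')
assert b3 is not b1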
+ + Examples: + >>> # get file backend based on the prefix of uri + >>> uri = 's3://path/of/your/file' + >>> backend = get_file_backend(uri) + >>> # get file backend based on the backend_args + >>> backend = get_file_backend(backend_args={'backend': 'petrel'}) + >>> # backend name has a higher priority if 'backend' in backend_args + >>> backend = get_file_backend(uri, backend_args={'backend': 'petrel'}) + """ + global backend_instances + + if backend_args is None: + backend_args = {} + + if uri is None and 'backend' not in backend_args: + raise ValueError( + 'uri should not be None when "backend" does not exist in ' + 'backend_args') + + if uri is not None: + prefix = _parse_uri_prefix(uri) + else: + prefix = '' + + if enable_singleton: + # TODO: whether to pass sort_key to json.dumps + unique_key = f'{prefix}:{json.dumps(backend_args)}' + if unique_key in backend_instances: + return backend_instances[unique_key] + + backend = _get_file_backend(prefix, backend_args) + backend_instances[unique_key] = backend + return backend + else: + backend = _get_file_backend(prefix, backend_args) + return backend + + +def get( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. -file_handlers = { - 'json': JsonHandler(), - 'yaml': YamlHandler(), - 'yml': YamlHandler(), - 'pickle': PickleHandler(), - 'pkl': PickleHandler() -} + Returns: + bytes: Expected bytes object. + Examples: + >>> filepath = '/path/of/file' + >>> get(filepath) + b'hello world' + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.get(filepath) -def load(file, file_format=None, file_client_args=None, **kwargs): + +def get_text( + filepath: Union[str, Path], + encoding='utf-8', + backend_args: Optional[dict] = None, +) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> filepath = '/path/of/file' + >>> get_text(filepath) + 'hello world' + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.get_text(filepath, encoding) + + +def put( + obj: bytes, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> filepath = '/path/of/file' + >>> put(b'hello world', filepath) + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + backend.put(obj, filepath) + + +def put_text( + obj: str, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> None: + """Write text to a given ``filepath`` with 'w' mode. 
+ + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + ``filepath``. Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> filepath = '/path/of/file' + >>> put_text('hello world', filepath) + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + backend.put_text(obj, filepath) + + +def exists( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> exists(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.exists(filepath) + + +def isdir( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/dir' + >>> isdir(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.isdir(filepath) + + +def isfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, +) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> isfile(filepath) + True + """ + backend = get_file_backend( + filepath, backend_args=backend_args, enable_singleton=True) + return backend.isfile(filepath) + + +def join_path( + filepath: Union[str, Path], + *filepaths: Union[str, Path], + backend_args: Optional[dict] = None, +) -> Union[str, Path]: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + *filepaths (str or Path): Other paths to be concatenated. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The result of concatenation. 
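As a complement to the local example that follows, a hedged sketch of how ``join_path`` is expected to dispatch per backend; the bucket name is made up, and the Petrel call is left as a comment because it needs petrel_client:

from mmengine.fileio import join_path

# LocalBackend delegates to os.path.join, so the platform separator is used.
print(join_path('/path/of/dir', 'subdir', 'file.txt'))
# -> '/path/of/dir/subdir/file.txt' on POSIX systems

# An 's3://' prefix would select PetrelBackend instead, which always joins
# path components with '/':
# join_path('s3://hypothetical-bucket/dir', 'file.txt')
# -> 's3://hypothetical-bucket/dir/file.txt'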
+
+    Examples:
+        >>> filepath1 = '/path/of/dir1'
+        >>> filepath2 = 'dir2'
+        >>> filepath3 = 'path/of/file'
+        >>> join_path(filepath1, filepath2, filepath3)
+        '/path/of/dir1/dir2/path/of/file'
+    """
+    backend = get_file_backend(
+        filepath, backend_args=backend_args, enable_singleton=True)
+    return backend.join_path(filepath, *filepaths)
+
+
+@contextmanager
+def get_local_path(
+    filepath: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Generator[Union[str, Path], None, None]:
+    """Download data from ``filepath`` and write the data to a local path.
+
+    ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+    can be called with a ``with`` statement, and when exiting the ``with``
+    statement, the temporary path will be released.
+
+    Note:
+        If the ``filepath`` is a local path, it is returned as-is and will
+        not be released (removed).
+
+    Args:
+        filepath (str or Path): Path of the data to be read.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Yields:
+        Iterable[str]: Only yield one path.
+
+    Examples:
+        >>> with get_local_path('s3://bucket/abc.jpg') as path:
+        ...     # do something here
+    """
+    backend = get_file_backend(
+        filepath, backend_args=backend_args, enable_singleton=True)
+    with backend.get_local_path(str(filepath)) as local_path:
+        yield local_path
+
+
+def copyfile(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Copy a file src to dst and return the destination file.
+
+    src and dst should have the same prefix. If dst specifies a directory,
+    the file will be copied into dst using the base filename from src. If
+    dst specifies a file that already exists, it will be replaced.
+
+    Args:
+        src (str or Path): A file to be copied.
+        dst (str or Path): Copy file to dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: The destination file.
+
+    Raises:
+        SameFileError: If src and dst are the same file, a SameFileError will
+            be raised.
+
+    Examples:
+        >>> # dst is a file
+        >>> src = '/path/of/file'
+        >>> dst = '/path1/of/file1'
+        >>> # src will be copied to '/path1/of/file1'
+        >>> copyfile(src, dst)
+        '/path1/of/file1'
+
+        >>> # dst is a directory
+        >>> dst = '/path1/of/dir'
+        >>> # src will be copied to '/path1/of/dir/file'
+        >>> copyfile(src, dst)
+        '/path1/of/dir/file'
+    """
+    backend = get_file_backend(
+        src, backend_args=backend_args, enable_singleton=True)
+    return backend.copyfile(src, dst)
+
+
+def copytree(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Recursively copy an entire directory tree rooted at src to a directory
+    named dst and return the destination directory.
+
+    src and dst should have the same prefix and dst must not already exist.
+
+    Args:
+        src (str or Path): A directory to be copied.
+        dst (str or Path): Copy directory to dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: The destination directory.
+
+    Raises:
+        FileExistsError: If dst already exists, a FileExistsError will be
+            raised.
+
+    Examples:
+        >>> src = '/path/of/dir1'
+        >>> dst = '/path/of/dir2'
+        >>> copytree(src, dst)
+        '/path/of/dir2'
+    """
+    backend = get_file_backend(
+        src, backend_args=backend_args, enable_singleton=True)
+    return backend.copytree(src, dst)
+
+
+def copyfile_from_local(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Copy a local file src to dst and return the destination file.
+
+    Note:
+        If the backend is an instance of LocalBackend, it does the same
+        thing as :func:`copyfile`.
+
+    Args:
+        src (str or Path): A local file to be copied.
+        dst (str or Path): Copy file to dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: If dst specifies a directory, the file will be copied into dst
+            using the base filename from src.
+
+    Examples:
+        >>> # dst is a file
+        >>> src = '/path/of/file'
+        >>> dst = 's3://openmmlab/mmengine/file1'
+        >>> # src will be copied to 's3://openmmlab/mmengine/file1'
+        >>> copyfile_from_local(src, dst)
+        's3://openmmlab/mmengine/file1'
+
+        >>> # dst is a directory
+        >>> dst = 's3://openmmlab/mmengine'
+        >>> # src will be copied to 's3://openmmlab/mmengine/file'
+        >>> copyfile_from_local(src, dst)
+        's3://openmmlab/mmengine/file'
+    """
+    backend = get_file_backend(
+        dst, backend_args=backend_args, enable_singleton=True)
+    return backend.copyfile_from_local(src, dst)
+
+
+def copytree_from_local(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Recursively copy an entire directory tree rooted at src to a directory
+    named dst and return the destination directory.
+
+    Note:
+        If the backend is an instance of LocalBackend, it does the same
+        thing as :func:`copytree`.
+
+    Args:
+        src (str or Path): A local directory to be copied.
+        dst (str or Path): Copy directory to dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: The destination directory.
+
+    Examples:
+        >>> src = '/path/of/dir'
+        >>> dst = 's3://openmmlab/mmengine/dir'
+        >>> copytree_from_local(src, dst)
+        's3://openmmlab/mmengine/dir'
+    """
+    backend = get_file_backend(
+        dst, backend_args=backend_args, enable_singleton=True)
+    return backend.copytree_from_local(src, dst)
+
+
+def copyfile_to_local(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Copy the file src to local dst and return the destination file.
+
+    If dst specifies a directory, the file will be copied into dst using
+    the base filename from src. If dst specifies a file that already
+    exists, it will be replaced.
+
+    Note:
+        If the backend is an instance of LocalBackend, it does the same
+        thing as :func:`copyfile`.
+
+    Args:
+        src (str or Path): A file to be copied.
+        dst (str or Path): Copy file to local dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: If dst specifies a directory, the file will be copied into dst
+            using the base filename from src.
+
+    Examples:
+        >>> # dst is a file
+        >>> src = 's3://openmmlab/mmengine/file'
+        >>> dst = '/path/of/file'
+        >>> # src will be copied to '/path/of/file'
+        >>> copyfile_to_local(src, dst)
+        '/path/of/file'
+
+        >>> # dst is a directory
+        >>> dst = '/path/of/dir'
+        >>> # src will be copied to '/path/of/dir/file'
+        >>> copyfile_to_local(src, dst)
+        '/path/of/dir/file'
+    """
+    backend = get_file_backend(
+        dst, backend_args=backend_args, enable_singleton=True)
+    return backend.copyfile_to_local(src, dst)
+
+
+def copytree_to_local(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> Union[str, Path]:
+    """Recursively copy an entire directory tree rooted at src to a local
+    directory named dst and return the destination directory.
+
+    Note:
+        If the backend is an instance of LocalBackend, it does the same
+        thing as :func:`copytree`.
+
+    Args:
+        src (str or Path): A directory to be copied.
+        dst (str or Path): Copy directory to local dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        str: The destination directory.
+
+    Examples:
+        >>> src = 's3://openmmlab/mmengine/dir'
+        >>> dst = '/path/of/dir'
+        >>> copytree_to_local(src, dst)
+        '/path/of/dir'
+    """
+    backend = get_file_backend(
+        dst, backend_args=backend_args, enable_singleton=True)
+    return backend.copytree_to_local(src, dst)
+
+
+def remove(
+    filepath: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> None:
+    """Remove a file.
+
+    Args:
+        filepath (str, Path): Path to be removed.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Raises:
+        FileNotFoundError: If filepath does not exist, a FileNotFoundError
+            will be raised.
+        IsADirectoryError: If filepath is a directory, an IsADirectoryError
+            will be raised.
+
+    Examples:
+        >>> filepath = '/path/of/file'
+        >>> remove(filepath)
+    """
+    backend = get_file_backend(
+        filepath, backend_args=backend_args, enable_singleton=True)
+    backend.remove(filepath)
+
+
+def rmtree(
+    dir_path: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> None:
+    """Recursively delete a directory tree.
+
+    Args:
+        dir_path (str or Path): A directory to be removed.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Examples:
+        >>> dir_path = '/path/of/dir'
+        >>> rmtree(dir_path)
+    """
+    backend = get_file_backend(
+        dir_path, backend_args=backend_args, enable_singleton=True)
+    backend.rmtree(dir_path)
+
+
+def copy_if_symlink_fails(
+    src: Union[str, Path],
+    dst: Union[str, Path],
+    backend_args: Optional[dict] = None,
+) -> bool:
+    """Create a symbolic link pointing to src named dst.
+
+    If it fails to create a symbolic link pointing to src, directly copy src
+    to dst instead.
+
+    Args:
+        src (str or Path): Create a symbolic link pointing to src.
+        dst (str or Path): Create a symbolic link named dst.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend. Defaults to None.
+
+    Returns:
+        bool: Return True if a symbolic link pointing to src is successfully
+            created. Otherwise, return False.
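The symlink-or-copy fallback is easiest to see in a small sketch; the work directory below is hypothetical, and the fallback matters on platforms (for example Windows without symlink privileges) where linking fails:

from mmengine.fileio import copy_if_symlink_fails, put

# Pretend this is a freshly saved checkpoint.
put(b'checkpoint-bytes', '/tmp/work_dir/epoch_1.pth')

# Publish it as 'latest.pth': prefer a symlink, fall back to a real copy
# when symlinks are unavailable. The return value reports which one happened.
linked = copy_if_symlink_fails('/tmp/work_dir/epoch_1.pth',
                               '/tmp/work_dir/latest.pth')
print('symlinked' if linked else 'copied')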
+ + Examples: + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> copy_if_symlink_fails(src, dst) + True + """ + backend = get_file_backend( + src, backend_args=backend_args, enable_singleton=True) + return backend.copy_if_symlink_fails(src, dst) + + +def list_dir_or_file( + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + backend_args: Optional[dict] = None, +) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # list those files and directories in current directory + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ + backend = get_file_backend( + dir_path, backend_args=backend_args, enable_singleton=True) + yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +def generate_presigned_url( + url: str, + client_method: str = 'get_object', + expires_in: int = 3600, + backend_args: Optional[dict] = None, +) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on Petrel backend. + + Note: + Now only work on Petrel backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Default: 'get_object'. + expires_in (int): expires, in seconds. Default: 3600. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Generated presigned url. + """ + backend = get_file_backend( + url, backend_args=backend_args, enable_singleton=True) + return backend.generate_presigned_url(url, client_method, expires_in) + + +def load(file, + file_format=None, + file_client_args=None, + backend_args=None, + **kwargs): """Load data from json/yaml/pickle files. This method provides a unified api for loading data from serialized files. @@ -32,7 +808,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs): "pickle/pkl". file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. 
-            Default: None.
+            Defaults to None. It will be deprecated in future. Please use
+            ``backend_args`` instead.
+        backend_args (dict, optional): Arguments to instantiate the
+            prefix of uri corresponding backend. Defaults to None.
+            New in v0.2.0.
 
     Examples:
         >>> load('/path/of/your/file')  # file is storaged in disk
@@ -49,14 +829,28 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
     if file_format not in file_handlers:
         raise TypeError(f'Unsupported format: {file_format}')
 
+    if file_client_args is not None:
+        warnings.warn(
+            '"file_client_args" will be deprecated in future. '
+            'Please use "backend_args" instead', DeprecationWarning)
+        if backend_args is not None:
+            raise ValueError(
+                '"file_client_args" and "backend_args" cannot be set at the '
+                'same time.')
+
     handler = file_handlers[file_format]
     if is_str(file):
-        file_client = FileClient.infer_client(file_client_args, file)
+        if file_client_args is not None:
+            file_client = FileClient.infer_client(file_client_args, file)
+            file_backend = file_client
+        else:
+            file_backend = get_file_backend(file, backend_args=backend_args)
+
         if handler.str_like:
-            with StringIO(file_client.get_text(file)) as f:
+            with StringIO(file_backend.get_text(file)) as f:
                 obj = handler.load_from_fileobj(f, **kwargs)
         else:
-            with BytesIO(file_client.get(file)) as f:
+            with BytesIO(file_backend.get(file)) as f:
                 obj = handler.load_from_fileobj(f, **kwargs)
     elif hasattr(file, 'read'):
         obj = handler.load_from_fileobj(file, **kwargs)
@@ -65,7 +859,12 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
     return obj
 
 
-def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+def dump(obj,
+         file=None,
+         file_format=None,
+         file_client_args=None,
+         backend_args=None,
+         **kwargs):
     """Dump data to json/yaml/pickle strings or files.
 
     This method provides a unified api for dumping data as strings or to files,
@@ -82,7 +881,11 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
         file_format (str, optional): Same as :func:`load`.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Default: None.
+            Defaults to None. It will be deprecated in future. Please use
+            ``backend_args`` instead.
+        backend_args (dict, optional): Arguments to instantiate the
+            prefix of uri corresponding backend. Defaults to None.
+            New in v0.2.0.
 
     Examples:
         >>> dump('hello world', '/path/of/your/file')  # disk
@@ -102,48 +905,34 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
     if file_format not in file_handlers:
         raise TypeError(f'Unsupported format: {file_format}')
 
+    if file_client_args is not None:
+        warnings.warn(
+            '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') + handler = file_handlers[file_format] if file is None: return handler.dump_to_str(obj, **kwargs) elif is_str(file): - file_client = FileClient.infer_client(file_client_args, file) + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend(file, backend_args=backend_args) + if handler.str_like: with StringIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - file_client.put_text(f.getvalue(), file) + file_backend.put_text(f.getvalue(), file) else: with BytesIO() as f: handler.dump_to_fileobj(obj, f, **kwargs) - file_client.put(f.getvalue(), file) + file_backend.put(f.getvalue(), file) elif hasattr(file, 'write'): handler.dump_to_fileobj(obj, file, **kwargs) else: raise TypeError('"file" must be a filename str or a file-object') - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. - """ - if not isinstance(handler, BaseFileHandler): - raise TypeError( - f'handler must be a child of BaseFileHandler, not {type(handler)}') - if isinstance(file_formats, str): - file_formats = [file_formats] - if not is_list_of(file_formats, str): - raise TypeError('file_formats must be a str or a list of str') - for ext in file_formats: - file_handlers[ext] = handler - - -def register_handler(file_formats, **kwargs): - - def wrap(cls): - _register_handler(cls(**kwargs), file_formats) - return cls - - return wrap diff --git a/mmengine/fileio/parse.py b/mmengine/fileio/parse.py index 8353b622..080ae023 100644 --- a/mmengine/fileio/parse.py +++ b/mmengine/fileio/parse.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings from io import StringIO from .file_client import FileClient +from .io import get_text def list_from_file(filename, @@ -9,7 +11,8 @@ def list_from_file(filename, offset=0, max_num=0, encoding='utf-8', - file_client_args=None): + file_client_args=None, + backend_args=None): """Load a text file and parse the content as a list of strings. ``list_from_file`` supports loading a text file which can be storaged in @@ -21,10 +24,14 @@ def list_from_file(filename, offset (int): The offset of lines. max_num (int): The maximum number of lines to be read, zeros and negatives mean no limitation. - encoding (str): Encoding used to open the file. Default utf-8. + encoding (str): Encoding used to open the file. Defaults to utf-8. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> list_from_file('/path/of/your/file') # disk @@ -35,10 +42,24 @@ def list_from_file(filename, Returns: list[str]: A list of strings. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') cnt = 0 item_list = [] - file_client = FileClient.infer_client(file_client_args, filename) - with StringIO(file_client.get_text(filename, encoding)) as f: + + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, filename) + text = file_client.get_text(filename, encoding) + else: + text = get_text(filename, encoding, backend_args=backend_args) + + with StringIO(text) as f: for _ in range(offset): f.readline() for line in f: @@ -52,7 +73,8 @@ def list_from_file(filename, def dict_from_file(filename, key_type=str, encoding='utf-8', - file_client_args=None): + file_client_args=None, + backend_args=None): """Load a text file and parse the content as a dict. Each line of the text file will be two or more columns split by @@ -66,10 +88,14 @@ def dict_from_file(filename, filename(str): Filename. key_type(type): Type of the dict keys. str is user by default and type conversion will be performed if specified. - encoding (str): Encoding used to open the file. Default utf-8. + encoding (str): Encoding used to open the file. Defaults to utf-8. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Default: None. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> dict_from_file('/path/of/your/file') # disk @@ -80,9 +106,24 @@ def dict_from_file(filename, Returns: dict: The parsed contents. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set at the ' + 'same time.') + mapping = {} - file_client = FileClient.infer_client(file_client_args, filename) - with StringIO(file_client.get_text(filename, encoding)) as f: + + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, filename) + text = file_client.get_text(filename, encoding) + else: + text = get_text(filename, encoding, backend_args=backend_args) + + with StringIO(text) as f: for line in f: items = line.rstrip('\n').split() assert len(items) >= 2 diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 7c1ca50b..cbd13124 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union from mmengine.dist import master_only -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, get_file_backend from mmengine.registry import HOOKS from mmengine.utils import is_list_of, is_seq_of from .hook import Hook @@ -72,14 +72,18 @@ class CheckpointHook(Hook): inferred by 'less' comparison rule. If ``None``, _default_less_keys will be used. Defaults to None. file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmcv.fileio.FileClient` for details. - Defaults to None. + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. 
Please use + ``backend_args`` instead. filename_tmpl (str, optional): String template to indicate checkpoint name. If specified, must contain one and only one "{}", which will be replaced with ``epoch + 1`` if ``by_epoch=True`` else ``iteration + 1``. Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth" accordingly. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> # Save best based on single metric @@ -123,6 +127,7 @@ class CheckpointHook(Hook): less_keys: Optional[Sequence[str]] = None, file_client_args: Optional[dict] = None, filename_tmpl: Optional[str] = None, + backend_args: Optional[dict] = None, **kwargs) -> None: self.interval = interval self.by_epoch = by_epoch @@ -131,7 +136,20 @@ class CheckpointHook(Hook): self.out_dir = out_dir # type: ignore self.max_keep_ckpts = max_keep_ckpts self.save_last = save_last + self.args = kwargs + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') + self.file_client_args = file_client_args + self.backend_args = backend_args + if filename_tmpl is None: if self.by_epoch: self.filename_tmpl = 'epoch_{}.pth' @@ -139,7 +157,6 @@ class CheckpointHook(Hook): self.filename_tmpl = 'iter_{}.pth' else: self.filename_tmpl = filename_tmpl - self.args = kwargs # save best logic assert (isinstance(save_best, str) or is_list_of(save_best, str) @@ -211,19 +228,28 @@ class CheckpointHook(Hook): if self.out_dir is None: self.out_dir = runner.work_dir + # If self.file_client_args is None, self.file_client will not + # used in CheckpointHook. 
To avoid breaking backward compatibility, + # it will not be removed util the release of MMEngine1.0 self.file_client = FileClient.infer_client(self.file_client_args, self.out_dir) + + if self.file_client_args is None: + self.file_backend = get_file_backend( + self.out_dir, backend_args=self.backend_args) + else: + self.file_backend = self.file_client + # if `self.out_dir` is not equal to `runner.work_dir`, it means that # `self.out_dir` is set so the final `self.out_dir` is the # concatenation of `self.out_dir` and the last level directory of # `runner.work_dir` if self.out_dir != runner.work_dir: basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_client.join_path( + self.out_dir = self.file_backend.join_path( self.out_dir, basename) # type: ignore # noqa: E501 - runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by ' - f'{self.file_client.name}.') + runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.') if self.save_best is not None: if len(self.key_indicators) == 1: @@ -302,11 +328,12 @@ class CheckpointHook(Hook): save_optimizer=self.save_optimizer, save_param_scheduler=self.save_param_scheduler, by_epoch=self.by_epoch, + backend_args=self.backend_args, **self.args) runner.message_hub.update_info( - 'last_ckpt', self.file_client.join_path(self.out_dir, - ckpt_filename)) + 'last_ckpt', + self.file_backend.join_path(self.out_dir, ckpt_filename)) # remove other checkpoints if self.max_keep_ckpts > 0: @@ -318,16 +345,15 @@ class CheckpointHook(Hook): current_ckpt - self.max_keep_ckpts * self.interval, 0, -self.interval) for _step in redundant_ckpts: - ckpt_path = self.file_client.join_path( + ckpt_path = self.file_backend.join_path( self.out_dir, self.filename_tmpl.format(_step)) - if self.file_client.isfile(ckpt_path): - self.file_client.remove(ckpt_path) + if self.file_backend.isfile(ckpt_path): + self.file_backend.remove(ckpt_path) else: break save_file = osp.join(runner.work_dir, 'last_checkpoint') - file_client = FileClient.infer_client(uri=self.out_dir) - filepath = file_client.join_path(self.out_dir, ckpt_filename) + filepath = self.file_backend.join_path(self.out_dir, ckpt_filename) with open(save_file, 'w') as f: f.write(filepath) @@ -404,7 +430,8 @@ class CheckpointHook(Hook): file_client_args=self.file_client_args, save_optimizer=False, save_param_scheduler=False, - by_epoch=False) + by_epoch=False, + backend_args=self.backend_args) runner.logger.info( f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index c189c2d1..71752be1 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import warnings from pathlib import Path from typing import Dict, Optional, Sequence, Union from mmengine.fileio import FileClient, dump +from mmengine.fileio.io import get_file_backend from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir @@ -50,12 +52,16 @@ class LoggerHook(Hook): removed. Defaults to True. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. + Defaults to None. It will be deprecated in future. Please use + `backend_args` instead. 
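As a usage sketch (not part of the patch itself): both hooks now take ``backend_args`` in place of the deprecated ``file_client_args``. The config below follows the usual ``default_hooks`` layout; the Petrel bucket paths and the ``{'backend': 'petrel'}`` dict are illustrative assumptions, mirroring the form that ``load_from_ceph`` passes to ``get_file_backend`` later in this patch.

# Sketch only: configure the hooks with ``backend_args`` instead of the
# deprecated ``file_client_args``; bucket paths are illustrative.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        out_dir='petrel://your_bucket/ckpts',   # assumed remote path
        backend_args={'backend': 'petrel'}),
    logger=dict(
        type='LoggerHook',
        out_dir='petrel://your_bucket/logs',    # assumed remote path
        backend_args={'backend': 'petrel'}))

Passing ``file_client_args`` and ``backend_args`` together raises a ``ValueError``, as the deprecation branches in the hooks show.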
log_metric_by_epoch (bool): Whether to output metric in validation step by epoch. It can be true when running in epoch based runner. If set to True, `after_val_epoch` will set `step` to self.epoch in `runner.visualizer.add_scalars`. Otherwise `step` will be self.iter. Default to True. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. Examples: >>> # The simplest LoggerHook config. @@ -71,7 +77,8 @@ class LoggerHook(Hook): out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), keep_local: bool = True, file_client_args: Optional[dict] = None, - log_metric_by_epoch: bool = True): + log_metric_by_epoch: bool = True, + backend_args: Optional[dict] = None): self.interval = interval self.ignore_last = ignore_last self.interval_exp_name = interval_exp_name @@ -82,6 +89,15 @@ class LoggerHook(Hook): 'specified.') self.out_dir = out_dir + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') + if not (out_dir is None or isinstance(out_dir, str) or is_tuple_of(out_dir, str)): raise TypeError('out_dir should be None or string or tuple of ' @@ -91,9 +107,16 @@ class LoggerHook(Hook): self.keep_local = keep_local self.file_client_args = file_client_args self.json_log_path: Optional[str] = None + if self.out_dir is not None: self.file_client = FileClient.infer_client(file_client_args, self.out_dir) + if file_client_args is None: + self.file_backend = get_file_backend( + self.out_dir, backend_args=backend_args) + else: + self.file_backend = self.file_client + self.log_metric_by_epoch = log_metric_by_epoch def before_run(self, runner) -> None: @@ -107,10 +130,10 @@ class LoggerHook(Hook): # The final `self.out_dir` is the concatenation of `self.out_dir` # and the last level directory of `runner.work_dir` basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_client.join_path(self.out_dir, basename) + self.out_dir = self.file_backend.join_path(self.out_dir, basename) runner.logger.info( - f'Text logs will be saved to {self.out_dir} by ' - f'{self.file_client.name} after the training process.') + f'Text logs will be saved to {self.out_dir} after the ' + 'training process.') self.json_log_path = f'{runner.timestamp}.json' @@ -245,9 +268,9 @@ class LoggerHook(Hook): return for filename in scandir(runner._log_dir, self.out_suffix, True): local_filepath = osp.join(runner._log_dir, filename) - out_filepath = self.file_client.join_path(self.out_dir, filename) + out_filepath = self.file_backend.join_path(self.out_dir, filename) with open(local_filepath) as f: - self.file_client.put_text(f.read(), out_filepath) + self.file_backend.put_text(f.read(), out_filepath) runner.logger.info( f'The file {local_filepath} has been uploaded to ' diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 9b80c4e0..9fe42b6d 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -15,7 +15,7 @@ import torchvision import mmengine from mmengine.dist import get_dist_info -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file from mmengine.logging import print_log from mmengine.model import is_model_wrapper @@ -334,7 +334,8 @@ def 
load_from_pavi(filename, map_location=None): return checkpoint -@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://') +@CheckpointLoader.register_scheme( + prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) def load_from_ceph(filename, map_location=None, backend='petrel'): """load checkpoint through the file path prefixed with s3. In distributed setting, this function download ckpt at all ranks to different temporary @@ -343,35 +344,15 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Args: filename (str): checkpoint file path with s3 prefix map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. Options are 'ceph', - 'petrel'. Default: 'petrel'. - - .. warning:: - :class:`mmengine.fileio.file_client.CephBackend` will be deprecated, - please use :class:`mmengine.fileio.file_client.PetrelBackend` instead. + backend (str, optional): The storage backend type. + Defaults to 'petrel'. Returns: dict or OrderedDict: The loaded checkpoint. """ - allowed_backends = ['ceph', 'petrel'] - if backend not in allowed_backends: - raise ValueError(f'Load from Backend {backend} is not supported.') - - if backend == 'ceph': - warnings.warn( - 'CephBackend will be deprecated, please use PetrelBackend instead', - DeprecationWarning) - - # CephClient and PetrelBackend have the same prefix 's3://' and the latter - # will be chosen as default. If PetrelBackend can not be instantiated - # successfully, the CephClient will be chosen. - try: - file_client = FileClient(backend=backend) - except ImportError: - allowed_backends.remove(backend) - file_client = FileClient(backend=allowed_backends[0]) - - with io.BytesIO(file_client.get(filename)) as buffer: + file_backend = get_file_backend( + filename, backend_args={'backend': backend}) + with io.BytesIO(file_backend.get(filename)) as buffer: checkpoint = torch.load(buffer, map_location=map_location) return checkpoint @@ -657,7 +638,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): return destination -def save_checkpoint(checkpoint, filename, file_client_args=None): +def save_checkpoint(checkpoint, + filename, + file_client_args=None, + backend_args=None): """Save checkpoint to file. Args: @@ -665,13 +649,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None): filename (str): Checkpoint filename. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. + Defaults to None. It will be deprecated in future. Please use + `backend_args` instead. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. """ + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') + if filename.startswith('pavi://'): - if file_client_args is not None: + if file_client_args is not None or backend_args is not None: raise ValueError( - 'file_client_args should be "None" if filename starts with' - f'"pavi://", but got {file_client_args}') + '"file_client_args" or "backend_args" should be "None" if ' + 'filename starts with "pavi://"') try: from pavi import exception, modelcloud except ImportError: @@ -692,9 +689,15 @@ def save_checkpoint(checkpoint, filename, file_client_args=None): model.create_file(checkpoint_file, name=model_name) else: file_client = FileClient.infer_client(file_client_args, filename) + if file_client_args is None: + file_backend = get_file_backend( + filename, backend_args=backend_args) + else: + file_backend = file_client + with io.BytesIO() as f: torch.save(checkpoint, f) - file_client.put(f.getvalue(), filename) + file_backend.put(f.getvalue(), filename) def find_latest_checkpoint(path: str) -> Optional[str]: diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index fb292af4..5d905ae2 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -24,7 +24,7 @@ from mmengine.device import get_device from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, is_distributed, master_only) from mmengine.evaluator import Evaluator -from mmengine.fileio import FileClient +from mmengine.fileio import FileClient, join_path from mmengine.hooks import Hook from mmengine.logging import MessageHub, MMLogger, print_log from mmengine.model import (BaseModel, MMDistributedDataParallel, @@ -2005,14 +2005,17 @@ class Runner: return checkpoint @master_only - def save_checkpoint(self, - out_dir: str, - filename: str, - file_client_args: Optional[dict] = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: dict = None, - by_epoch: bool = True): + def save_checkpoint( + self, + out_dir: str, + filename: str, + file_client_args: Optional[dict] = None, + save_optimizer: bool = True, + save_param_scheduler: bool = True, + meta: dict = None, + by_epoch: bool = True, + backend_args: Optional[dict] = None, + ): """Save checkpoints. ``CheckpointHook`` invokes this method to save checkpoints @@ -2022,7 +2025,9 @@ class Runner: out_dir (str): The directory that checkpoints are saved. filename (str): The checkpoint filename. file_client_args (dict, optional): Arguments to instantiate a - FileClient. Default: None. + FileClient. See :class:`mmengine.fileio.FileClient` for + details. Defaults to None. It will be deprecated in future. + Please use `backend_args` instead. save_optimizer (bool): Whether to save the optimizer to the checkpoint. Defaults to True. save_param_scheduler (bool): Whether to save the param_scheduler @@ -2031,6 +2036,9 @@ class Runner: checkpoint. Defaults to None. by_epoch (bool): Whether the scheduled momentum is updated by epochs. Defaults to True. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + New in v0.2.0. 
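A minimal call sketch for the new keyword (not part of the patch; the S3 path is illustrative and ``runner`` is assumed to be an existing ``Runner`` instance). Note that ``backend_args`` must stay ``None`` whenever ``file_client_args`` is still supplied, otherwise the guard below raises a ``ValueError``.

# Sketch: save a checkpoint to an object-storage path via ``backend_args``.
# The dict form mirrors what ``load_from_ceph`` passes to ``get_file_backend``.
runner.save_checkpoint(
    out_dir='s3://your_bucket/checkpoints',   # assumed remote directory
    filename='epoch_1.pth',
    backend_args={'backend': 'petrel'})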
""" if meta is None: meta = {} @@ -2047,8 +2055,20 @@ class Runner: else: meta.update(epoch=self.epoch, iter=self.iter + 1) - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set at ' + 'the same time.') + + file_client = FileClient.infer_client(file_client_args, out_dir) + filepath = file_client.join_path(out_dir, filename) + else: + filepath = join_path( # type: ignore + out_dir, filename, backend_args=backend_args) meta.update( cfg=self.cfg.pretty_text, diff --git a/requirements/tests.txt b/requirements/tests.txt index c5043abf..debf7eb1 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,3 +1,4 @@ coverage lmdb +parameterized pytest diff --git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py new file mode 100644 index 00000000..7903f557 --- /dev/null +++ b/tests/test_fileio/test_backends/test_backend_utils.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.fileio.backends import (BaseStorageBackend, backends, + prefix_to_backends, register_backend) + + +def test_register_backend(): + # 1. two ways to register backend + # 1.1 use it as a decorator + @register_backend('example') + class ExampleBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + assert 'example' in backends + + # 1.2 use it as a normal function + class ExampleBackend1(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + register_backend('example1', ExampleBackend1) + assert 'example1' in backends + + # 2. test `name` parameter + # 2. name should a string + with pytest.raises(TypeError, match='name should be a string'): + register_backend(1, ExampleBackend) + + register_backend('example2', ExampleBackend) + assert 'example2' in backends + + # 3. test `backend` parameter + # If backend is not None, it should be a class and a subclass of + # BaseStorageBackend. + with pytest.raises(TypeError, match='backend should be a class'): + + def test_backend(): + pass + + register_backend('example3', test_backend) + + class ExampleBackend2: + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + with pytest.raises( + TypeError, match='not a subclass of BaseStorageBackend'): + register_backend('example3', ExampleBackend2) + + # 4. test `force` parameter + # 4.1 force=False + with pytest.raises(ValueError, match='example is already registered'): + register_backend('example', ExampleBackend) + + # 4.2 force=True + register_backend('example', ExampleBackend, force=True) + assert 'example' in backends + + # 5. 
test `prefixes` parameter + class ExampleBackend3(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + # 5.1 prefixes is a string + register_backend('example3', ExampleBackend3, prefixes='prefix1') + assert 'example3' in backends + assert 'prefix1' in prefix_to_backends + + # 5.2 prefixes is a list (tuple) of strings + register_backend( + 'example4', ExampleBackend3, prefixes=['prefix2', 'prefix3']) + assert 'example4' in backends + assert 'prefix2' in prefix_to_backends + assert 'prefix3' in prefix_to_backends + assert prefix_to_backends['prefix2'] == prefix_to_backends['prefix3'] + + # 5.3 prefixes is an invalid type + with pytest.raises(AssertionError): + register_backend('example5', ExampleBackend3, prefixes=1) + + # 5.4 prefixes is already registered + with pytest.raises(ValueError, match='prefix2 is already registered'): + register_backend('example6', ExampleBackend3, prefixes='prefix2') + + class ExampleBackend4(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + register_backend( + 'example6', ExampleBackend4, prefixes='prefix2', force=True) + assert 'example6' in backends + assert 'prefix2' in prefix_to_backends diff --git a/tests/test_fileio/test_backends/test_base_storage_backend.py b/tests/test_fileio/test_backends/test_base_storage_backend.py new file mode 100644 index 00000000..6aa60885 --- /dev/null +++ b/tests/test_fileio/test_backends/test_base_storage_backend.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.fileio.backends import BaseStorageBackend + + +def test_base_storage_backend(): + # test inheritance + class ExampleBackend(BaseStorageBackend): + pass + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class ExampleBackend"): + ExampleBackend() + + class ExampleBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + backend = ExampleBackend() + assert backend.get('test') == 'test' + assert backend.get_text('test') == 'test' diff --git a/tests/test_fileio/test_backends/test_http_backend.py b/tests/test_fileio/test_backends/test_http_backend.py new file mode 100644 index 00000000..c69394d1 --- /dev/null +++ b/tests/test_fileio/test_backends/test_http_backend.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
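The registry and base-class tests above pin down the minimal contract for a custom backend: subclass ``BaseStorageBackend``, implement ``get`` and ``get_text``, and register it by name, optionally with URI prefixes. A standalone sketch of that pattern follows; the ``toy`` name, prefix and return values are purely illustrative, not part of this patch.

# Sketch of a user-defined backend registered with the new interface.
from mmengine.fileio.backends import BaseStorageBackend, register_backend


class ToyBackend(BaseStorageBackend):

    def get(self, filepath):
        # return the path as raw bytes
        return str(filepath).encode()

    def get_text(self, filepath):
        # return the path as text
        return str(filepath)


register_backend('toy', ToyBackend, prefixes='toy')

After registration, ``'toy'`` appears in ``backends`` and in ``prefix_to_backends``, exactly as the assertions above check for the example backends.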
+from pathlib import Path +from unittest import TestCase + +import cv2 +import numpy as np + +from mmengine.fileio.backends import HTTPBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +def imread(path): + with open(path, 'rb') as f: + content = f.read() + img = imfrombytes(content) + return img + + +class TestHTTPBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.img_url = ( + 'https://download.openmmlab.com/mmengine/test-data/color.jpg') + cls.img_shape = (300, 400, 3) + cls.text_url = ( + 'https://download.openmmlab.com/mmengine/test-data/filelist.txt') + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.text_path = cls.test_data_dir / 'filelist.txt' + + def test_get(self): + backend = HTTPBackend() + img_bytes = backend.get(self.img_url) + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, self.img_shape) + + def test_get_text(self): + backend = HTTPBackend() + text = backend.get_text(self.text_url) + self.assertEqual(self.text_path.open('r').read(), text) + + def test_get_local_path(self): + backend = HTTPBackend() + with backend.get_local_path(self.img_url) as filepath: + img = imread(filepath) + self.assertEqual(img.shape, self.img_shape) diff --git a/tests/test_fileio/test_backends/test_lmdb_backend.py b/tests/test_fileio/test_backends/test_lmdb_backend.py new file mode 100644 index 00000000..dc2c7ded --- /dev/null +++ b/tests/test_fileio/test_backends/test_lmdb_backend.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path +from unittest import TestCase + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import LmdbBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +class TestLmdbBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.lmdb_path = cls.test_data_dir / 'demo.lmdb' + + @parameterized.expand([[Path], [str]]) + def test_get(self, path_type): + backend = LmdbBackend(path_type(self.lmdb_path)) + img_bytes = backend.get('baboon') + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, (120, 125, 3)) + + def test_get_text(self): + backend = LmdbBackend(self.lmdb_path) + with self.assertRaises(NotImplementedError): + backend.get_text('filepath') diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py new file mode 100644 index 00000000..427ebf78 --- /dev/null +++ b/tests/test_fileio/test_backends/test_local_backend.py @@ -0,0 +1,486 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import platform +import tempfile +from contextlib import contextmanager +from pathlib import Path +from shutil import SameFileError +from unittest import TestCase +from unittest.mock import patch + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import LocalBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . 
\n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + +class TestLocalBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.img_path = cls.test_data_dir / 'color.jpg' + cls.img_shape = (300, 400, 3) + cls.text_path = cls.test_data_dir / 'filelist.txt' + + def test_name(self): + backend = LocalBackend() + self.assertEqual(backend.name, 'LocalBackend') + + @parameterized.expand([[Path], [str]]) + def test_get(self, path_type): + backend = LocalBackend() + img_bytes = backend.get(path_type(self.img_path)) + self.assertEqual(self.img_path.open('rb').read(), img_bytes) + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, self.img_shape) + + @parameterized.expand([[Path], [str]]) + def test_get_text(self, path_type): + backend = LocalBackend() + text = backend.get_text(path_type(self.text_path)) + self.assertEqual(self.text_path.open('r').read(), text) + + @parameterized.expand([[Path], [str]]) + def test_put(self, path_type): + backend = LocalBackend() + + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'test.jpg' + backend.put(b'disk', path_type(filepath)) + self.assertEqual(backend.get(filepath), b'disk') + + # If the directory does not exist, put will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' + backend.put(b'disk', path_type(filepath)) + self.assertEqual(backend.get(filepath), b'disk') + + @parameterized.expand([[Path], [str]]) + def test_put_text(self, path_type): + backend = LocalBackend() + + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', path_type(filepath)) + self.assertEqual(backend.get_text(filepath), 'disk') + + # If the directory does not exist, put_text will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' + backend.put_text('disk', path_type(filepath)) + self.assertEqual(backend.get_text(filepath), 'disk') + + @parameterized.expand([[Path], [str]]) + def test_exists(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + self.assertTrue(backend.exists(path_type(tmp_dir))) + filepath = Path(tmp_dir) / 'test.txt' + self.assertFalse(backend.exists(path_type(filepath))) + backend.put_text('disk', filepath) + self.assertTrue(backend.exists(path_type(filepath))) + backend.remove(filepath) + + @parameterized.expand([[Path], [str]]) + def test_isdir(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + self.assertTrue(backend.isdir(path_type(tmp_dir))) + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertFalse(backend.isdir(path_type(filepath))) + + @parameterized.expand([[Path], [str]]) + def test_isfile(self, path_type): + backend = LocalBackend() + with 
tempfile.TemporaryDirectory() as tmp_dir: + self.assertFalse(backend.isfile(path_type(tmp_dir))) + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertTrue(backend.isfile(path_type(filepath))) + + @parameterized.expand([[Path], [str]]) + def test_join_path(self, path_type): + backend = LocalBackend() + filepath = backend.join_path( + path_type(self.test_data_dir), path_type('file')) + expected = osp.join(path_type(self.test_data_dir), path_type('file')) + self.assertEqual(filepath, expected) + + filepath = backend.join_path( + path_type(self.test_data_dir), path_type('dir'), path_type('file')) + expected = osp.join( + path_type(self.test_data_dir), path_type('dir'), path_type('file')) + self.assertEqual(filepath, expected) + + @parameterized.expand([[Path], [str]]) + def test_get_local_path(self, path_type): + backend = LocalBackend() + with backend.get_local_path(path_type(self.text_path)) as filepath: + self.assertEqual(path_type(self.text_path), path_type(filepath)) + + @parameterized.expand([[Path], [str]]) + def test_copyfile(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + # dst is a directory + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( + backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_copyfile_from_local(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree_from_local(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( 
+ backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_copyfile_to_local(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test.txt.bak' + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) + self.assertEqual(backend.get_text(dst), 'disk') + + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual( + backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') + + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(path_type(src), path_type(src)) + + @parameterized.expand([[Path], [str]]) + def test_copytree_to_local(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + src = Path(tmp_dir) / 'dir1' + dst = Path(tmp_dir) / 'dir100' + self.assertEqual( + backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) + self.assertTrue(backend.isdir(dst)) + self.assertTrue(backend.isfile(dst / 'text3.txt')) + self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') + + # dst should not exist + with self.assertRaises(FileExistsError): + backend.copytree( + path_type(src), path_type(Path(tmp_dir) / 'dir2')) + + @parameterized.expand([[Path], [str]]) + def test_remove(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + # filepath is a Path object + filepath = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', filepath) + self.assertTrue(backend.exists(filepath)) + backend.remove(path_type(filepath)) + self.assertFalse(backend.exists(filepath)) + + # raise error if file does not exist + with self.assertRaises(FileNotFoundError): + filepath = Path(tmp_dir) / 'test1.txt' + backend.remove(path_type(filepath)) + + # can not remove directory + filepath = Path(tmp_dir) / 'dir' + filepath.mkdir() + with self.assertRaises(IsADirectoryError): + backend.remove(path_type(filepath)) + + @parameterized.expand([[Path], [str]]) + def test_rmtree(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # src and dst are Path objects + dir_path = Path(tmp_dir) / 'dir1' + self.assertTrue(backend.exists(dir_path)) + backend.rmtree(path_type(dir_path)) + self.assertFalse(backend.exists(dir_path)) + + dir_path = Path(tmp_dir) / 'dir2' + self.assertTrue(backend.exists(dir_path)) + backend.rmtree(path_type(dir_path)) + self.assertFalse(backend.exists(dir_path)) + + @parameterized.expand([[Path], [str]]) + def test_copy_if_symlink_fails(self, path_type): + backend = LocalBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + # create a symlink for a file + src = Path(tmp_dir) / 'test.txt' + backend.put_text('disk', src) + dst = Path(tmp_dir) / 'test_link.txt' + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) + if platform.system() == 'Linux': + self.assertTrue(res) + self.assertTrue(osp.islink(dst)) + 
self.assertEqual(backend.get_text(dst), 'disk') + + # create a symlink for a directory + src = Path(tmp_dir) / 'dir' + src.mkdir() + dst = Path(tmp_dir) / 'dir_link' + res = backend.copy_if_symlink_fails(path_type(src), path_type(dst)) + if platform.system() == 'Linux': + self.assertTrue(res) + self.assertTrue(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + def symlink(src, dst): + raise Exception + + # copy files if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'test.txt' + dst = Path(tmp_dir) / 'test_link1.txt' + res = backend.copy_if_symlink_fails( + path_type(src), path_type(dst)) + self.assertFalse(res) + self.assertFalse(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + # copy directory if symblink fails + with patch.object(os, 'symlink', side_effect=symlink): + src = Path(tmp_dir) / 'dir' + dst = Path(tmp_dir) / 'dir_link1' + res = backend.copy_if_symlink_fails( + path_type(src), path_type(dst)) + self.assertFalse(res) + self.assertFalse(osp.islink(dst)) + self.assertTrue(backend.exists(dst)) + + @parameterized.expand([[Path], [str]]) + def test_list_dir_or_file(self, path_type): + backend = LocalBackend() + with build_temporary_directory() as tmp_dir: + # list directories and files + self.assertEqual( + set(backend.list_dir_or_file(path_type(tmp_dir))), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + + # list directories and files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), recursive=True)), + { + 'dir1', + osp.join('dir1', 'text3.txt'), 'dir2', + osp.join('dir2', 'dir3'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) + + # only list directories + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False)), + {'dir1', 'dir2'}) + + with self.assertRaisesRegex( + TypeError, + '`suffix` should be None when `list_dir` is True'): + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False, suffix='.txt') + + # only list directories recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_file=False, recursive=True)), + {'dir1', 'dir2', osp.join('dir2', 'dir3')}) + + # only list files + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False, recursive=True)), + { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix='.txt', + recursive=True)), { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + }) + + # only list files ending with suffix + 
self.assertEqual( + set( + backend.list_dir_or_file( + path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), + { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + }) diff --git a/tests/test_fileio/test_backends/test_memcached_backend.py b/tests/test_fileio/test_backends/test_memcached_backend.py new file mode 100644 index 00000000..d320fcb1 --- /dev/null +++ b/tests/test_fileio/test_backends/test_memcached_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from pathlib import Path +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import cv2 +import numpy as np +from parameterized import parameterized + +from mmengine.fileio.backends import MemcachedBackend + + +def imfrombytes(content): + img_np = np.frombuffer(content, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + return img + + +sys.modules['mc'] = MagicMock() + + +class MockMemcachedClient: + + def __init__(self, server_list_cfg, client_cfg): + pass + + def Get(self, filepath, buffer): + with open(filepath, 'rb') as f: + buffer.content = f.read() + + +class TestMemcachedBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None) + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.img_path = cls.test_data_dir / 'color.jpg' + cls.img_shape = (300, 400, 3) + + @parameterized.expand([[Path], [str]]) + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) + @patch('mc.pyvector', MagicMock) + @patch('mc.ConvertBuffer', lambda x: x.content) + def test_get(self, path_type): + backend = MemcachedBackend(**self.mc_cfg) + img_bytes = backend.get(path_type(self.img_path)) + self.assertEqual(self.img_path.open('rb').read(), img_bytes) + img = imfrombytes(img_bytes) + self.assertEqual(img.shape, self.img_shape) + + @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient) + @patch('mc.pyvector', MagicMock) + @patch('mc.ConvertBuffer', lambda x: x.content) + def test_get_text(self): + backend = MemcachedBackend(**self.mc_cfg) + with self.assertRaises(NotImplementedError): + backend.get_text('filepath') diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py new file mode 100644 index 00000000..63c9284a --- /dev/null +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -0,0 +1,858 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import sys +import tempfile +from contextlib import contextmanager +from copy import deepcopy +from pathlib import Path +from shutil import SameFileError +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from mmengine.fileio.backends import PetrelBackend +from mmengine.utils import has_method + + +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . 
\n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + +try: + # Other unit tests may mock these modules so we need to pop them first. + sys.modules.pop('petrel_client', None) + sys.modules.pop('petrel_client.client', None) + + # If petrel_client is imported successfully, we can test PetrelBackend + # without mock. + import petrel_client # noqa: F401 +except ImportError: + sys.modules['petrel_client'] = MagicMock() + sys.modules['petrel_client.client'] = MagicMock() + + class MockPetrelClient: + + def __init__(self, enable_mc=True, enable_multi_cluster=False): + self.enable_mc = enable_mc + self.enable_multi_cluster = enable_multi_cluster + + def Get(self, filepath): + with open(filepath, 'rb') as f: + content = f.read() + return content + + def put(self): + pass + + def delete(self): + pass + + def contains(self): + pass + + def isdir(self): + pass + + def list(self, dir_path): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + yield entry.name + elif osp.isdir(entry.path): + yield entry.name + '/' + + @contextmanager + def delete_and_reset_method(obj, method): + method_obj = deepcopy(getattr(type(obj), method)) + try: + delattr(type(obj), method) + yield + finally: + setattr(type(obj), method, method_obj) + + @patch('petrel_client.client.Client', MockPetrelClient) + class TestPetrelBackend(TestCase): + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.img_path = cls.test_data_dir / 'color.jpg' + cls.img_shape = (300, 400, 3) + cls.text_path = cls.test_data_dir / 'filelist.txt' + cls.petrel_dir = 'petrel://user/data' + cls.petrel_path = f'{cls.petrel_dir}/test.jpg' + cls.expected_dir = 's3://user/data' + cls.expected_path = f'{cls.expected_dir}/test.jpg' + + def test_name(self): + backend = PetrelBackend() + self.assertEqual(backend.name, 'PetrelBackend') + + def test_map_path(self): + backend = PetrelBackend(path_mapping=None) + self.assertEqual( + backend._map_path(self.petrel_path), self.petrel_path) + + backend = PetrelBackend( + path_mapping={'data/': 'petrel://user/data/'}) + self.assertEqual( + backend._map_path('data/test.jpg'), self.petrel_path) + + def test_format_path(self): + backend = PetrelBackend() + formatted_filepath = backend._format_path( + 'petrel://user\\data\\test.jpg') + self.assertEqual(formatted_filepath, self.petrel_path) + + def test_replace_prefix(self): + backend = PetrelBackend() + self.assertEqual( + backend._replace_prefix(self.petrel_path), self.expected_path) + + def test_join_path(self): + backend = PetrelBackend() + self.assertEqual( + backend.join_path(self.petrel_dir, 'file'), + f'{self.petrel_dir}/file') + self.assertEqual( + backend.join_path(f'{self.petrel_dir}/', 'file'), + f'{self.petrel_dir}/file') + self.assertEqual( + backend.join_path(f'{self.petrel_dir}/', '/file'), + f'{self.petrel_dir}/file') + 
self.assertEqual( + backend.join_path(self.petrel_dir, 'dir', 'file'), + f'{self.petrel_dir}/dir/file') + + def test_get(self): + backend = PetrelBackend() + with patch.object( + backend._client, 'Get', + return_value=b'petrel') as patched_get: + self.assertEqual(backend.get(self.petrel_path), b'petrel') + patched_get.assert_called_once_with(self.expected_path) + + def test_get_text(self): + backend = PetrelBackend() + with patch.object( + backend._client, 'Get', + return_value=b'petrel') as patched_get: + self.assertEqual(backend.get_text(self.petrel_path), 'petrel') + patched_get.assert_called_once_with(self.expected_path) + + def test_put(self): + backend = PetrelBackend() + with patch.object(backend._client, 'put') as patched_put: + backend.put(b'petrel', self.petrel_path) + patched_put.assert_called_once_with(self.expected_path, + b'petrel') + + def test_put_text(self): + backend = PetrelBackend() + with patch.object(backend._client, 'put') as patched_put: + backend.put_text('petrel', self.petrel_path) + patched_put.assert_called_once_with(self.expected_path, + b'petrel') + + def test_exists(self): + backend = PetrelBackend() + self.assertTrue(has_method(backend._client, 'contains')) + self.assertTrue(has_method(backend._client, 'isdir')) + # raise Exception if `_client.contains` and '_client.isdir' are not + # implemented + with delete_and_reset_method(backend._client, 'contains'), \ + delete_and_reset_method(backend._client, 'isdir'): + self.assertFalse(has_method(backend._client, 'contains')) + self.assertFalse(has_method(backend._client, 'isdir')) + with self.assertRaises(NotImplementedError): + backend.exists(self.petrel_path) + + with patch.object( + backend._client, 'contains', + return_value=True) as patched_contains: + self.assertTrue(backend.exists(self.petrel_path)) + patched_contains.assert_called_once_with(self.expected_path) + + def test_isdir(self): + backend = PetrelBackend() + self.assertTrue(has_method(backend._client, 'isdir')) + # raise Exception if `_client.isdir` is not implemented + with delete_and_reset_method(backend._client, 'isdir'): + self.assertFalse(has_method(backend._client, 'isdir')) + with self.assertRaises(NotImplementedError): + backend.isdir(self.petrel_path) + + with patch.object( + backend._client, 'isdir', + return_value=True) as patched_contains: + self.assertTrue(backend.isdir(self.petrel_path)) + patched_contains.assert_called_once_with(self.expected_path) + + def test_isfile(self): + backend = PetrelBackend() + self.assertTrue(has_method(backend._client, 'contains')) + # raise Exception if `_client.contains` is not implemented + with delete_and_reset_method(backend._client, 'contains'): + self.assertFalse(has_method(backend._client, 'contains')) + with self.assertRaises(NotImplementedError): + backend.isfile(self.petrel_path) + + with patch.object( + backend._client, 'contains', + return_value=True) as patched_contains: + self.assertTrue(backend.isfile(self.petrel_path)) + patched_contains.assert_called_once_with(self.expected_path) + + def test_get_local_path(self): + backend = PetrelBackend() + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + patch.object(backend._client, 'contains', + return_value=True) as patch_contains: + with backend.get_local_path(self.petrel_path) as path: + self.assertTrue(osp.isfile(path)) + self.assertEqual(Path(path).open('rb').read(), b'petrel') + # exist the with block and path will be released + self.assertFalse(osp.isfile(path)) + 
patched_get.assert_called_once_with(self.expected_path) + patch_contains.assert_called_once_with(self.expected_path) + + def test_copyfile(self): + backend = PetrelBackend() + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + patch.object(backend._client, 'put') as patched_put, \ + patch.object(backend._client, 'isdir', return_value=False) as \ + patched_isdir: + src = self.petrel_path + dst = f'{self.petrel_dir}/test.bak.jpg' + expected_dst = f'{self.expected_dir}/test.bak.jpg' + self.assertEqual(backend.copyfile(src, dst), dst) + patched_get.assert_called_once_with(self.expected_path) + patched_put.assert_called_once_with(expected_dst, b'petrel') + patched_isdir.assert_called_once_with(expected_dst) + + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + patch.object(backend._client, 'put') as patched_put, \ + patch.object(backend._client, 'isdir', return_value=True) as \ + patched_isdir: + # dst is a directory + dst = f'{self.petrel_dir}/dir' + expected_dst = f'{self.expected_dir}/dir/test.jpg' + self.assertEqual(backend.copyfile(src, dst), f'{dst}/test.jpg') + patched_get.assert_called_once_with(self.expected_path) + patched_put.assert_called_once_with(expected_dst, b'petrel') + patched_isdir.assert_called_once_with( + f'{self.expected_dir}/dir') + + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + patch.object(backend._client, 'isdir', return_value=False) as \ + patched_isdir: + # src and src should not be same file + with self.assertRaises(SameFileError): + backend.copyfile(src, src) + + def test_copytree(self): + backend = PetrelBackend() + put_inputs = [] + get_inputs = [] + + def put(obj, filepath): + put_inputs.append((obj, filepath)) + + def get(filepath): + get_inputs.append(filepath) + + with build_temporary_directory() as tmp_dir, \ + patch.object(backend, 'put', side_effect=put),\ + patch.object(backend, 'get', side_effect=get),\ + patch.object(backend, 'exists', return_value=False): + tmp_dir = tmp_dir.replace('\\', '/') + dst = f'{tmp_dir}/dir' + self.assertEqual(backend.copytree(tmp_dir, dst), dst) + + self.assertEqual(len(put_inputs), 5) + self.assertEqual(len(get_inputs), 5) + + # dst should not exist + with patch.object(backend, 'exists', return_value=True): + with self.assertRaises(FileExistsError): + backend.copytree(dst, tmp_dir) + + def test_copyfile_from_local(self): + backend = PetrelBackend() + with patch.object(backend._client, 'put') as patched_put, \ + patch.object(backend._client, 'isdir', return_value=False) \ + as patched_isdir: + src = self.img_path + dst = f'{self.petrel_dir}/color.bak.jpg' + expected_dst = f'{self.expected_dir}/color.bak.jpg' + self.assertEqual(backend.copyfile_from_local(src, dst), dst) + patched_put.assert_called_once_with(expected_dst, + src.open('rb').read()) + patched_isdir.assert_called_once_with(expected_dst) + + with patch.object(backend._client, 'put') as patched_put, \ + patch.object(backend._client, 'isdir', return_value=True) as \ + patched_isdir: + # dst is a directory + src = self.img_path + dst = f'{self.petrel_dir}/dir' + expected_dst = f'{self.expected_dir}/dir/color.jpg' + self.assertEqual( + backend.copyfile_from_local(src, dst), f'{dst}/color.jpg') + patched_put.assert_called_once_with(expected_dst, + src.open('rb').read()) + patched_isdir.assert_called_once_with( + f'{self.expected_dir}/dir') + + def test_copytree_from_local(self): + backend = PetrelBackend() + inputs = [] + + def copyfile_from_local(src, 
dst): + inputs.append((src, dst)) + + with build_temporary_directory() as tmp_dir, \ + patch.object(backend, 'copyfile_from_local', + side_effect=copyfile_from_local),\ + patch.object(backend, 'exists', return_value=False): + backend.copytree_from_local(tmp_dir, self.petrel_dir) + + self.assertEqual(len(inputs), 5) + + # dst should not exist + with patch.object(backend, 'exists', return_value=True): + with self.assertRaises(FileExistsError): + backend.copytree_from_local(tmp_dir, self.petrel_dir) + + def test_copyfile_to_local(self): + backend = PetrelBackend() + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + tempfile.TemporaryDirectory() as tmp_dir: + src = self.petrel_path + dst = Path(tmp_dir) / 'test.bak.jpg' + self.assertEqual(backend.copyfile_to_local(src, dst), dst) + patched_get.assert_called_once_with(self.expected_path) + self.assertEqual(dst.open('rb').read(), b'petrel') + + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get, \ + tempfile.TemporaryDirectory() as tmp_dir: + # dst is a directory + src = self.petrel_path + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile_to_local(src, dst), dst / 'test.jpg') + patched_get.assert_called_once_with(self.expected_path) + self.assertEqual((dst / 'test.jpg').open('rb').read(), + b'petrel') + + def test_copytree_to_local(self): + backend = PetrelBackend() + inputs = [] + + def get(filepath): + inputs.append(filepath) + return b'petrel' + + with build_temporary_directory() as tmp_dir, \ + patch.object(backend, 'get', side_effect=get): + dst = f'{tmp_dir}/dir' + backend.copytree_to_local(tmp_dir, dst) + + self.assertEqual(len(inputs), 5) + + def test_remove(self): + backend = PetrelBackend() + self.assertTrue(has_method(backend._client, 'delete')) + # raise Exception if `delete` is not implemented + with delete_and_reset_method(backend._client, 'delete'): + self.assertFalse(has_method(backend._client, 'delete')) + with self.assertRaises(NotImplementedError): + backend.remove(self.petrel_path) + + with patch.object(backend._client, 'delete') as patched_delete, \ + patch.object(backend._client, 'isdir', return_value=False) \ + as patched_isdir, \ + patch.object(backend._client, 'contains', return_value=True) \ + as patched_contains: + backend.remove(self.petrel_path) + patched_delete.assert_called_once_with(self.expected_path) + patched_isdir.assert_called_once_with(self.expected_path) + patched_contains.assert_called_once_with(self.expected_path) + + def test_rmtree(self): + backend = PetrelBackend() + inputs = [] + + def remove(filepath): + inputs.append(filepath) + + with build_temporary_directory() as tmp_dir,\ + patch.object(backend, 'remove', side_effect=remove): + backend.rmtree(tmp_dir) + + self.assertEqual(len(inputs), 5) + + def test_copy_if_symlink_fails(self): + backend = PetrelBackend() + copyfile_inputs = [] + copytree_inputs = [] + + def copyfile(src, dst): + copyfile_inputs.append((src, dst)) + + def copytree(src, dst): + copytree_inputs.append((src, dst)) + + with patch.object(backend, 'copyfile', side_effect=copyfile), \ + patch.object(backend, 'isfile', return_value=True): + backend.copy_if_symlink_fails(self.petrel_path, 'path') + + self.assertEqual(len(copyfile_inputs), 1) + + with patch.object(backend, 'copytree', side_effect=copytree), \ + patch.object(backend, 'isfile', return_value=False): + backend.copy_if_symlink_fails(self.petrel_dir, 'path') + + self.assertEqual(len(copytree_inputs), 1) + + def 
test_list_dir_or_file(self): + backend = PetrelBackend() + + # raise Exception if `_client.list` is not implemented + self.assertTrue(has_method(backend._client, 'list')) + with delete_and_reset_method(backend._client, 'list'): + self.assertFalse(has_method(backend._client, 'list')) + with self.assertRaises(NotImplementedError): + list(backend.list_dir_or_file(self.petrel_dir)) + + with build_temporary_directory() as tmp_dir: + # list directories and files + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir)), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + + # list directories and files recursively + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir, recursive=True)), { + 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', + '/'.join(('dir2', 'dir3')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list directories + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir, list_file=False)), + {'dir1', 'dir2'}) + with self.assertRaisesRegex( + TypeError, + '`list_dir` should be False when `suffix` is not None' + ): + backend.list_dir_or_file( + tmp_dir, list_file=False, suffix='.txt') + + # only list directories recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_file=False, recursive=True)), + {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) + + # only list files + self.assertEqual( + set(backend.list_dir_or_file(tmp_dir, list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix=('.txt', '.jpg'))), + {'text1.txt', 'text2.txt'}) + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)), { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), + 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + def test_generate_presigned_url(self): + pass + +else: + + class TestPetrelBackend(TestCase): # type: ignore + + @classmethod + def setUpClass(cls): + cls.test_data_dir = Path(__file__).parent.parent.parent / 'data' + cls.local_img_path = cls.test_data_dir / 'color.jpg' + cls.local_img_shape = (300, 400, 3) + cls.petrel_dir = 'petrel://mmengine-test/data' + + def setUp(self): + backend = PetrelBackend() + backend.rmtree(self.petrel_dir) + with build_temporary_directory() as tmp_dir: + backend.copytree_from_local(tmp_dir, self.petrel_dir) + + text1_path = f'{self.petrel_dir}/text1.txt' + text2_path = f'{self.petrel_dir}/text2.txt' + text3_path = 
f'{self.petrel_dir}/dir1/text3.txt'
+            text4_path = f'{self.petrel_dir}/dir2/dir3/text4.txt'
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            self.assertTrue(backend.isfile(text1_path))
+            self.assertTrue(backend.isfile(text2_path))
+            self.assertTrue(backend.isfile(text3_path))
+            self.assertTrue(backend.isfile(text4_path))
+            self.assertTrue(backend.isfile(img_path))
+
+        def test_get(self):
+            backend = PetrelBackend()
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            self.assertEqual(backend.get(img_path), b'img')
+
+        def test_get_text(self):
+            backend = PetrelBackend()
+            text_path = f'{self.petrel_dir}/text1.txt'
+            self.assertEqual(backend.get_text(text_path), 'text1')
+
+        def test_put(self):
+            backend = PetrelBackend()
+            img_path = f'{self.petrel_dir}/img.jpg'
+            backend.put(b'img', img_path)
+
+        def test_put_text(self):
+            backend = PetrelBackend()
+            text_path = f'{self.petrel_dir}/text5.txt'
+            backend.put_text('text5', text_path)
+
+        def test_exists(self):
+            backend = PetrelBackend()
+
+            # file and directory exist
+            dir_path = f'{self.petrel_dir}/dir2'
+            self.assertTrue(backend.exists(dir_path))
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            self.assertTrue(backend.exists(img_path))
+
+            # file and directory do not exist
+            not_existed_dir = f'{self.petrel_dir}/not_existed_dir'
+            self.assertFalse(backend.exists(not_existed_dir))
+            not_existed_path = f'{self.petrel_dir}/img.jpg'
+            self.assertFalse(backend.exists(not_existed_path))
+
+        def test_isdir(self):
+            backend = PetrelBackend()
+            dir_path = f'{self.petrel_dir}/dir2'
+            self.assertTrue(backend.isdir(dir_path))
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            self.assertFalse(backend.isdir(img_path))
+
+        def test_isfile(self):
+            backend = PetrelBackend()
+            dir_path = f'{self.petrel_dir}/dir2'
+            self.assertFalse(backend.isfile(dir_path))
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            self.assertTrue(backend.isfile(img_path))
+
+        def test_get_local_path(self):
+            backend = PetrelBackend()
+            img_path = f'{self.petrel_dir}/dir2/img.jpg'
+            with backend.get_local_path(img_path) as path:
+                self.assertTrue(osp.isfile(path))
+                self.assertEqual(Path(path).open('rb').read(), b'img')
+            # exiting the with block releases the temporary local path
+            self.assertFalse(osp.isfile(path))
+
+        def test_copyfile(self):
+            backend = PetrelBackend()
+
+            # dst is a file
+            src = f'{self.petrel_dir}/dir2/img.jpg'
+            dst = f'{self.petrel_dir}/img.jpg'
+            self.assertEqual(backend.copyfile(src, dst), dst)
+            self.assertTrue(backend.isfile(dst))
+
+            # dst is a directory
+            dst = f'{self.petrel_dir}/dir1'
+            expected_dst = f'{self.petrel_dir}/dir1/img.jpg'
+            self.assertEqual(backend.copyfile(src, dst), expected_dst)
+            self.assertTrue(backend.isfile(expected_dst))
+
+            # src and dst should not be the same file
+            with self.assertRaises(SameFileError):
+                backend.copyfile(src, src)
+
+        def test_copytree(self):
+            backend = PetrelBackend()
+            src = f'{self.petrel_dir}/dir2'
+            dst = f'{self.petrel_dir}/dir3'
+            self.assertFalse(backend.exists(dst))
+            self.assertEqual(backend.copytree(src, dst), dst)
+            self.assertEqual(
+                list(backend.list_dir_or_file(src)),
+                list(backend.list_dir_or_file(dst)))
+
+            # dst should not exist
+            with self.assertRaises(FileExistsError):
+                backend.copytree(src, dst)
+
+        def test_copyfile_from_local(self):
+            backend = PetrelBackend()
+
+            # dst is a file
+            src = self.local_img_path
+            dst = f'{self.petrel_dir}/color.jpg'
+            self.assertFalse(backend.exists(dst))
+            self.assertEqual(backend.copyfile_from_local(src, dst), dst)
+            self.assertTrue(backend.isfile(dst))
+
+            # dst is a
directory + src = self.local_img_path + dst = f'{self.petrel_dir}/dir1' + expected_dst = f'{self.petrel_dir}/dir1/color.jpg' + self.assertFalse(backend.exists(expected_dst)) + self.assertEqual( + backend.copyfile_from_local(src, dst), expected_dst) + self.assertTrue(backend.isfile(expected_dst)) + + def test_copytree_from_local(self): + backend = PetrelBackend() + backend.rmtree(self.petrel_dir) + with build_temporary_directory() as tmp_dir: + backend.copytree_from_local(tmp_dir, self.petrel_dir) + files = backend.list_dir_or_file( + self.petrel_dir, recursive=True) + self.assertEqual(len(list(files)), 8) + + def test_copyfile_to_local(self): + backend = PetrelBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + # dst is a file + src = f'{self.petrel_dir}/dir2/img.jpg' + dst = Path(tmp_dir) / 'img.jpg' + self.assertEqual(backend.copyfile_to_local(src, dst), dst) + self.assertEqual(dst.open('rb').read(), b'img') + + # dst is a directory + dst = Path(tmp_dir) / 'dir' + dst.mkdir() + self.assertEqual( + backend.copyfile_to_local(src, dst), dst / 'img.jpg') + self.assertEqual((dst / 'img.jpg').open('rb').read(), b'img') + + def test_copytree_to_local(self): + backend = PetrelBackend() + with tempfile.TemporaryDirectory() as tmp_dir: + backend.copytree_to_local(self.petrel_dir, tmp_dir) + self.assertTrue(osp.exists(Path(tmp_dir) / 'text1.txt')) + self.assertTrue(osp.exists(Path(tmp_dir) / 'dir2' / 'img.jpg')) + + def test_remove(self): + backend = PetrelBackend() + img_path = f'{self.petrel_dir}/dir2/img.jpg' + self.assertTrue(backend.isfile(img_path)) + backend.remove(img_path) + self.assertFalse(backend.exists(img_path)) + + def test_rmtree(self): + backend = PetrelBackend() + dir_path = f'{self.petrel_dir}/dir2' + self.assertTrue(backend.isdir(dir_path)) + backend.rmtree(dir_path) + self.assertFalse(backend.exists(dir_path)) + + def test_copy_if_symlink_fails(self): + backend = PetrelBackend() + + # dst is a file + src = f'{self.petrel_dir}/dir2/img.jpg' + dst = f'{self.petrel_dir}/img.jpg' + self.assertFalse(backend.exists(dst)) + self.assertFalse(backend.copy_if_symlink_fails(src, dst)) + self.assertTrue(backend.isfile(dst)) + + # dst is a directory + src = f'{self.petrel_dir}/dir2' + dst = f'{self.petrel_dir}/dir' + self.assertFalse(backend.exists(dst)) + self.assertFalse(backend.copy_if_symlink_fails(src, dst)) + self.assertTrue(backend.isdir(dst)) + + def test_list_dir_or_file(self): + backend = PetrelBackend() + + # list directories and files + self.assertEqual( + set(backend.list_dir_or_file(self.petrel_dir)), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + + # list directories and files recursively + self.assertEqual( + set(backend.list_dir_or_file(self.petrel_dir, recursive=True)), + { + 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join( + ('dir2', 'dir3')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list directories + self.assertEqual( + set( + backend.list_dir_or_file(self.petrel_dir, + list_file=False)), + {'dir1', 'dir2'}) + with self.assertRaisesRegex( + TypeError, + '`list_dir` should be False when `suffix` is not None'): + backend.list_dir_or_file( + self.petrel_dir, list_file=False, suffix='.txt') + + # only list directories recursively + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, list_file=False, recursive=True)), + {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) + + # only list files + self.assertEqual( + set(backend.list_dir_or_file(self.petrel_dir, 
list_dir=False)), + {'text1.txt', 'text2.txt'}) + + # only list files recursively + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, suffix='.txt')), + {'text1.txt', 'text2.txt'}) + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + with self.assertRaisesRegex( + TypeError, + '`suffix` must be a string or tuple of strings'): + backend.list_dir_or_file( + self.petrel_dir, list_dir=False, suffix=['.txt', '.jpg']) + + # only list files ending with suffix recursively + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix='.txt', + recursive=True)), { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', + 'text2.txt' + }) + + # only list files ending with suffix + self.assertEqual( + set( + backend.list_dir_or_file( + self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), '/'.join( + ('dir2', 'img.jpg')), 'text1.txt', 'text2.txt' + }) diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py index 629bd7f6..3620ddb0 100644 --- a/tests/test_fileio/test_fileclient.py +++ b/tests/test_fileio/test_fileclient.py @@ -12,7 +12,7 @@ import cv2 import numpy as np import pytest -from mmengine import BaseStorageBackend, FileClient +from mmengine.fileio import BaseStorageBackend, FileClient from mmengine.utils import has_method sys.modules['ceph'] = MagicMock() @@ -354,9 +354,15 @@ class TestFileClient: petrel_backend.remove(petrel_path) with patch.object(petrel_backend.client._client, - 'delete') as mock_delete: + 'delete') as mock_delete, \ + patch.object(petrel_backend.client._client, + 'isdir', return_value=False) as mock_isdir, \ + patch.object(petrel_backend.client._client, + 'contains', return_value=True) as mock_contains: petrel_backend.remove(petrel_path) mock_delete.assert_called_once_with(petrel_path) + mock_isdir.assert_called_once_with(petrel_path) + mock_contains.assert_called_once_with(petrel_path) # test `exists` assert has_method(petrel_backend.client._client, 'contains') diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 3077a948..33a0956f 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch import pytest import mmengine -from mmengine.fileio.file_client import HTTPBackend, PetrelBackend +from mmengine.fileio import HTTPBackend, PetrelBackend sys.modules['petrel_client'] = MagicMock() sys.modules['petrel_client.client'] = MagicMock() @@ -151,22 +151,26 @@ def test_list_from_file(): assert filelist == ['4.jpg', '5.jpg'] # get list from http + filename = 'http://path/of/your/file' with patch.object( HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filename = 'http://path/of/your/file' filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file( filename, file_client_args={'prefix': 'http'}) assert 
filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmengine.list_from_file( + filename, backend_args={'backend': 'http'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] # get list from petrel + filename = 's3://path/of/your/file' with patch.object( PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): - filename = 's3://path/of/your/file' filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] @@ -175,6 +179,9 @@ def test_list_from_file(): assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] + filelist = mmengine.list_from_file( + filename, backend_args={'backend': 'petrel'}) + assert filelist == ['1.jpg', '2.jpg', '3.jpg'] def test_dict_from_file(): @@ -186,28 +193,36 @@ def test_dict_from_file(): assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'} # get dict from http + filename = 'http://path/of/your/file' with patch.object( HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - filename = 'http://path/of/your/file' mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( filename, file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file( + filename, backend_args={'backend': 'http'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} # get dict from petrel + filename = 's3://path/of/your/file' with patch.object( PetrelBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - filename = 's3://path/of/your/file' mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file( filename, file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} + mapping = mmengine.dict_from_file( + filename, backend_args={'backend': 'petrel'}) + assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py new file mode 100644 index 00000000..1e8698cc --- /dev/null +++ b/tests/test_fileio/test_io.py @@ -0,0 +1,532 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import platform +import sys +import tempfile +from contextlib import contextmanager +from pathlib import Path +from shutil import SameFileError +from unittest.mock import MagicMock, patch + +import pytest + +import mmengine.fileio as fileio + +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() + +test_data_dir = Path(__file__).parent.parent / 'data' +text_path = test_data_dir / 'filelist.txt' +img_path = test_data_dir / 'color.jpg' +img_url = 'https://raw.githubusercontent.com/mmengine/tests/data/img.png' + + +@contextmanager +def build_temporary_directory(): + """Build a temporary directory containing many files to test + ``FileClient.list_dir_or_file``. + + . 
\n + | -- dir1 \n + | -- | -- text3.txt \n + | -- dir2 \n + | -- | -- dir3 \n + | -- | -- | -- text4.txt \n + | -- | -- img.jpg \n + | -- text1.txt \n + | -- text2.txt \n + """ + with tempfile.TemporaryDirectory() as tmp_dir: + text1 = Path(tmp_dir) / 'text1.txt' + text1.open('w').write('text1') + text2 = Path(tmp_dir) / 'text2.txt' + text2.open('w').write('text2') + dir1 = Path(tmp_dir) / 'dir1' + dir1.mkdir() + text3 = dir1 / 'text3.txt' + text3.open('w').write('text3') + dir2 = Path(tmp_dir) / 'dir2' + dir2.mkdir() + jpg1 = dir2 / 'img.jpg' + jpg1.open('wb').write(b'img') + dir3 = dir2 / 'dir3' + dir3.mkdir() + text4 = dir3 / 'text4.txt' + text4.open('w').write('text4') + yield tmp_dir + + +def test_parse_uri_prefix(): + # input path is None + with pytest.raises(AssertionError): + fileio.io._parse_uri_prefix(None) + + # input path is list + with pytest.raises(AssertionError): + fileio.io._parse_uri_prefix([]) + + # input path is Path object + assert fileio.io._parse_uri_prefix(uri=text_path) == '' + + # input path starts with https + assert fileio.io._parse_uri_prefix(uri=img_url) == 'https' + + # input path starts with s3 + uri = 's3://your_bucket/img.png' + assert fileio.io._parse_uri_prefix(uri) == 's3' + + # input path starts with clusterName:s3 + uri = 'clusterName:s3://your_bucket/img.png' + assert fileio.io._parse_uri_prefix(uri) == 's3' + + +def test_get_file_backend(): + # other unit tests may have added instances so clear them here. + fileio.io.backend_instances = {} + + # uri should not be None when "backend" does not exist in backend_args + with pytest.raises(ValueError, match='uri should not be None'): + fileio.get_file_backend(None, backend_args=None) + + # uri is not None + backend = fileio.get_file_backend(uri=text_path) + assert isinstance(backend, fileio.backends.LocalBackend) + + uri = 'petrel://your_bucket/img.png' + backend = fileio.get_file_backend(uri=uri) + assert isinstance(backend, fileio.backends.PetrelBackend) + + backend = fileio.get_file_backend(uri=img_url) + assert isinstance(backend, fileio.backends.HTTPBackend) + uri = 'http://raw.githubusercontent.com/mmengine/tests/data/img.png' + backend = fileio.get_file_backend(uri=uri) + assert isinstance(backend, fileio.backends.HTTPBackend) + + # backend_args is not None and it contains a backend name + backend_args = {'backend': 'local'} + backend = fileio.get_file_backend(uri=None, backend_args=backend_args) + assert isinstance(backend, fileio.backends.LocalBackend) + + backend_args = {'backend': 'petrel', 'enable_mc': True} + backend = fileio.get_file_backend(uri=None, backend_args=backend_args) + assert isinstance(backend, fileio.backends.PetrelBackend) + + # backend name has a higher priority + backend_args = {'backend': 'http'} + backend = fileio.get_file_backend(uri=text_path, backend_args=backend_args) + assert isinstance(backend, fileio.backends.HTTPBackend) + + # test enable_singleton parameter + assert len(fileio.io.backend_instances) == 0 + backend1 = fileio.get_file_backend(uri=text_path, enable_singleton=True) + assert isinstance(backend1, fileio.backends.LocalBackend) + assert len(fileio.io.backend_instances) == 1 + assert fileio.io.backend_instances[':{}'] is backend1 + + backend2 = fileio.get_file_backend(uri=text_path, enable_singleton=True) + assert isinstance(backend2, fileio.backends.LocalBackend) + assert len(fileio.io.backend_instances) == 1 + assert backend2 is backend1 + + backend3 = fileio.get_file_backend(uri=text_path, enable_singleton=False) + assert isinstance(backend3, 
fileio.backends.LocalBackend) + assert len(fileio.io.backend_instances) == 1 + assert backend3 is not backend2 + + backend_args = {'path_mapping': {'src': 'dst'}, 'enable_mc': True} + uri = 'petrel://your_bucket/img.png' + backend4 = fileio.get_file_backend( + uri=uri, backend_args=backend_args, enable_singleton=True) + assert isinstance(backend4, fileio.backends.PetrelBackend) + assert len(fileio.io.backend_instances) == 2 + unique_key = 'petrel:{"path_mapping": {"src": "dst"}, "enable_mc": true}' + assert fileio.io.backend_instances[unique_key] is backend4 + assert backend4 is not backend2 + + uri = 'petrel://your_bucket/img1.png' + backend5 = fileio.get_file_backend( + uri=uri, backend_args=backend_args, enable_singleton=True) + assert isinstance(backend5, fileio.backends.PetrelBackend) + assert len(fileio.io.backend_instances) == 2 + assert backend5 is backend4 + assert backend5 is not backend2 + + backend_args = {'path_mapping': {'src1': 'dst1'}, 'enable_mc': True} + backend6 = fileio.get_file_backend( + uri=uri, backend_args=backend_args, enable_singleton=True) + assert isinstance(backend6, fileio.backends.PetrelBackend) + assert len(fileio.io.backend_instances) == 3 + unique_key = 'petrel:{"path_mapping": {"src1": "dst1"}, "enable_mc": true}' + assert fileio.io.backend_instances[unique_key] is backend6 + assert backend6 is not backend4 + assert backend6 is not backend5 + + backend7 = fileio.get_file_backend( + uri=uri, backend_args=backend_args, enable_singleton=False) + assert isinstance(backend7, fileio.backends.PetrelBackend) + assert len(fileio.io.backend_instances) == 3 + assert backend7 is not backend6 + + +def test_get(): + # test LocalBackend + filepath = Path(img_path) + img_bytes = fileio.get(filepath) + assert filepath.open('rb').read() == img_bytes + + +def test_get_text(): + # test LocalBackend + filepath = Path(text_path) + text = fileio.get_text(filepath) + assert filepath.open('r').read() == text + + +def test_put(): + # test LocalBackend + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'img.png' + fileio.put(b'disk', filepath) + assert fileio.get(filepath) == b'disk' + + # If the directory does not exist, put will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.jpg' + fileio.put(b'disk', filepath) + assert fileio.get(filepath) == b'disk' + + +def test_put_text(): + # test LocalBackend + with tempfile.TemporaryDirectory() as tmp_dir: + filepath = Path(tmp_dir) / 'text.txt' + fileio.put_text('text', filepath) + assert fileio.get_text(filepath) == 'text' + + # If the directory does not exist, put_text will create a + # directory first + filepath = Path(tmp_dir) / 'not_existed_dir' / 'test.txt' + fileio.put_text('disk', filepath) + assert fileio.get_text(filepath) == 'disk' + + +def test_exists(): + # test LocalBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert fileio.exists(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + assert not fileio.exists(filepath) + fileio.put_text('disk', filepath) + assert fileio.exists(filepath) + + +def test_isdir(): + # test LocalBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert fileio.isdir(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', filepath) + assert not fileio.isdir(filepath) + + +def test_isfile(): + # test LocalBackend + with tempfile.TemporaryDirectory() as tmp_dir: + assert not fileio.isfile(tmp_dir) + filepath = Path(tmp_dir) / 'test.txt' + fileio.put_text('disk', filepath) + assert 
fileio.isfile(filepath)
+
+
+def test_join_path():
+    # test LocalBackend
+    filepath = fileio.join_path(test_data_dir, 'file')
+    expected = osp.join(test_data_dir, 'file')
+    assert filepath == expected
+
+    filepath = fileio.join_path(test_data_dir, 'dir', 'file')
+    expected = osp.join(test_data_dir, 'dir', 'file')
+    assert filepath == expected
+
+
+def test_get_local_path():
+    # test LocalBackend
+    with fileio.get_local_path(text_path) as filepath:
+        assert str(text_path) == filepath
+
+
+def test_copyfile():
+    # test LocalBackend
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        src = Path(tmp_dir) / 'test.txt'
+        fileio.put_text('disk', src)
+        dst = Path(tmp_dir) / 'test.txt.bak'
+        assert fileio.copyfile(src, dst) == dst
+        assert fileio.get_text(dst) == 'disk'
+
+        # dst is a directory
+        dst = Path(tmp_dir) / 'dir'
+        dst.mkdir()
+        assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt')
+        assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk'
+
+        # src and dst should not be the same file
+        with pytest.raises(SameFileError):
+            fileio.copyfile(src, src)
+
+
+def test_copytree():
+    # test LocalBackend
+    with build_temporary_directory() as tmp_dir:
+        # src and dst are Path objects
+        src = Path(tmp_dir) / 'dir1'
+        dst = Path(tmp_dir) / 'dir100'
+        assert fileio.copytree(src, dst) == dst
+        assert fileio.isdir(dst)
+        assert fileio.isfile(dst / 'text3.txt')
+        assert fileio.get_text(dst / 'text3.txt') == 'text3'
+
+        # dst should not exist
+        with pytest.raises(FileExistsError):
+            fileio.copytree(src, Path(tmp_dir) / 'dir2')
+
+
+def test_copyfile_from_local():
+    # test LocalBackend
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        src = Path(tmp_dir) / 'test.txt'
+        fileio.put_text('disk', src)
+        dst = Path(tmp_dir) / 'test.txt.bak'
+        assert fileio.copyfile(src, dst) == dst
+        assert fileio.get_text(dst) == 'disk'
+
+        dst = Path(tmp_dir) / 'dir'
+        dst.mkdir()
+        assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt')
+        assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk'
+
+        # src and dst should not be the same file
+        with pytest.raises(SameFileError):
+            fileio.copyfile(src, src)
+
+
+def test_copytree_from_local():
+    # test LocalBackend
+    with build_temporary_directory() as tmp_dir:
+        # src and dst are Path objects
+        src = Path(tmp_dir) / 'dir1'
+        dst = Path(tmp_dir) / 'dir100'
+        assert fileio.copytree(src, dst) == dst
+        assert fileio.isdir(dst)
+        assert fileio.isfile(dst / 'text3.txt')
+        assert fileio.get_text(dst / 'text3.txt') == 'text3'
+
+        # dst should not exist
+        with pytest.raises(FileExistsError):
+            fileio.copytree(src, Path(tmp_dir) / 'dir2')
+
+
+def test_copyfile_to_local():
+    # test LocalBackend
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        src = Path(tmp_dir) / 'test.txt'
+        fileio.put_text('disk', src)
+        dst = Path(tmp_dir) / 'test.txt.bak'
+        assert fileio.copyfile(src, dst) == dst
+        assert fileio.get_text(dst) == 'disk'
+
+        dst = Path(tmp_dir) / 'dir'
+        dst.mkdir()
+        assert fileio.copyfile(src, dst) == fileio.join_path(dst, 'test.txt')
+        assert fileio.get_text(fileio.join_path(dst, 'test.txt')) == 'disk'
+
+        # src and dst should not be the same file
+        with pytest.raises(SameFileError):
+            fileio.copyfile(src, src)
+
+
+def test_copytree_to_local():
+    # test LocalBackend
+    with build_temporary_directory() as tmp_dir:
+        # src and dst are Path objects
+        src = Path(tmp_dir) / 'dir1'
+        dst = Path(tmp_dir) / 'dir100'
+        assert fileio.copytree(src, dst) == dst
+        assert fileio.isdir(dst)
+        assert fileio.isfile(dst / 'text3.txt')
+        assert fileio.get_text(dst / 'text3.txt') == 'text3'
+
+        # dst should not exist
+        with pytest.raises(FileExistsError):
+            fileio.copytree(src, Path(tmp_dir) / 'dir2')
+
+
+def test_remove():
+    # test LocalBackend
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # filepath is a Path object
+        filepath = Path(tmp_dir) / 'test.txt'
+        fileio.put_text('disk', filepath)
+        assert fileio.exists(filepath)
+        fileio.remove(filepath)
+        assert not fileio.exists(filepath)
+
+        # raise error if file does not exist
+        with pytest.raises(FileNotFoundError):
+            filepath = Path(tmp_dir) / 'test1.txt'
+            fileio.remove(filepath)
+
+        # cannot remove a directory
+        filepath = Path(tmp_dir) / 'dir'
+        filepath.mkdir()
+        with pytest.raises(IsADirectoryError):
+            fileio.remove(filepath)
+
+
+def test_rmtree():
+    # test LocalBackend
+    with build_temporary_directory() as tmp_dir:
+        # dir_path is a Path object
+        dir_path = Path(tmp_dir) / 'dir1'
+        assert fileio.exists(dir_path)
+        fileio.rmtree(dir_path)
+        assert not fileio.exists(dir_path)
+
+        dir_path = Path(tmp_dir) / 'dir2'
+        assert fileio.exists(dir_path)
+        fileio.rmtree(dir_path)
+        assert not fileio.exists(dir_path)
+
+
+def test_copy_if_symlink_fails():
+    # test LocalBackend
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        # create a symlink for a file
+        src = Path(tmp_dir) / 'test.txt'
+        fileio.put_text('disk', src)
+        dst = Path(tmp_dir) / 'test_link.txt'
+        res = fileio.copy_if_symlink_fails(src, dst)
+        if platform.system() == 'Linux':
+            assert res
+            assert osp.islink(dst)
+        assert fileio.get_text(dst) == 'disk'
+
+        # create a symlink for a directory
+        src = Path(tmp_dir) / 'dir'
+        src.mkdir()
+        dst = Path(tmp_dir) / 'dir_link'
+        res = fileio.copy_if_symlink_fails(src, dst)
+        if platform.system() == 'Linux':
+            assert res
+            assert osp.islink(dst)
+        assert fileio.exists(dst)
+
+        def symlink(src, dst):
+            raise Exception
+
+        # copy files if symlink fails
+        with patch.object(os, 'symlink', side_effect=symlink):
+            src = Path(tmp_dir) / 'test.txt'
+            dst = Path(tmp_dir) / 'test_link1.txt'
+            res = fileio.copy_if_symlink_fails(src, dst)
+            assert not res
+            assert not osp.islink(dst)
+            assert fileio.exists(dst)
+
+        # copy directory if symlink fails
+        with patch.object(os, 'symlink', side_effect=symlink):
+            src = Path(tmp_dir) / 'dir'
+            dst = Path(tmp_dir) / 'dir_link1'
+            res = fileio.copy_if_symlink_fails(src, dst)
+            assert not res
+            assert not osp.islink(dst)
+            assert fileio.exists(dst)
+
+
+def test_list_dir_or_file():
+    # test LocalBackend
+    with build_temporary_directory() as tmp_dir:
+        # list directories and files
+        assert set(fileio.list_dir_or_file(tmp_dir)) == {
+            'dir1', 'dir2', 'text1.txt', 'text2.txt'
+        }
+
+        # list directories and files recursively
+        assert set(fileio.list_dir_or_file(tmp_dir, recursive=True)) == {
+            'dir1',
+            osp.join('dir1', 'text3.txt'), 'dir2',
+            osp.join('dir2', 'dir3'),
+            osp.join('dir2', 'dir3', 'text4.txt'),
+            osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
+        }
+
+        # only list directories
+        assert set(fileio.list_dir_or_file(
+            tmp_dir, list_file=False)) == {'dir1', 'dir2'}
+
+        with pytest.raises(
+                TypeError,
+                match='`suffix` should be None when `list_dir` is True'):
+            list(
+                fileio.list_dir_or_file(
+                    tmp_dir, list_file=False, suffix='.txt'))
+
+        # only list directories recursively
+        assert set(
+            fileio.list_dir_or_file(
+                tmp_dir, list_file=False,
+                recursive=True)) == {'dir1', 'dir2',
+                                     osp.join('dir2', 'dir3')}
+
+        # only list files
+        assert set(fileio.list_dir_or_file(
+            tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
+
# only list files recursively + assert set( + fileio.list_dir_or_file(tmp_dir, list_dir=False, + recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), + 'text1.txt', 'text2.txt' + } + + # only list files ending with suffix + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, + suffix='.txt')) == {'text1.txt', 'text2.txt'} + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, + suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + + with pytest.raises( + TypeError, + match='`suffix` must be a string or tuple of strings'): + list( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])) + + # only list files ending with suffix recursively + assert set( + fileio.list_dir_or_file( + tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + } + + # only list files ending with suffix + assert set( + fileio.list_dir_or_file( + tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' + } diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index deb01992..8fbb1a56 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest import torch @@ -9,6 +9,7 @@ import torch.nn as nn from torch.utils.data import Dataset from mmengine.evaluator import BaseMetric +from mmengine.fileio import FileClient, LocalBackend from mmengine.hooks import CheckpointHook from mmengine.logging import MessageHub from mmengine.model import BaseModel @@ -71,34 +72,41 @@ class TriangleMetric(BaseMetric): return dict(acc=acc) -class MockPetrel: - - _allow_symlink = False - - def __init__(self): - pass - - @property - def name(self): - return self.__class__.__name__ - - @property - def allow_symlink(self): - return self._allow_symlink - - -prefix_to_backends = {'s3': MockPetrel} - - class TestCheckpointHook: - @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends', - prefix_to_backends) + def test_init(self, tmp_path): + # Test file_client_args and backend_args + with pytest.warns( + DeprecationWarning, + match='"file_client_args" will be deprecated in future'): + CheckpointHook(file_client_args={'backend': 'disk'}) + + with pytest.raises( + ValueError, + match='"file_client_args" and "backend_args" cannot be set ' + 'at the same time'): + CheckpointHook( + file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) + def test_before_train(self, tmp_path): runner = Mock() work_dir = str(tmp_path) runner.work_dir = work_dir + # file_client_args is None + checkpoint_hook = CheckpointHook() + checkpoint_hook.before_train(runner) + assert isinstance(checkpoint_hook.file_client, FileClient) + assert isinstance(checkpoint_hook.file_backend, LocalBackend) + + # file_client_args is not None + checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'}) + checkpoint_hook.before_train(runner) + assert isinstance(checkpoint_hook.file_client, FileClient) + # file_backend is the alias of file_client + assert checkpoint_hook.file_backend is 
checkpoint_hook.file_client + # the out_dir of the checkpoint hook is None checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) checkpoint_hook.before_train(runner) @@ -392,6 +400,26 @@ class TestCheckpointHook: assert (runner.epoch + 1) % 2 == 0 assert not os.path.exists(f'{work_dir}/epoch_8.pth') + # save_checkpoint of runner should be called with expected arguments + runner = Mock() + work_dir = str(tmp_path) + runner.work_dir = tmp_path + runner.epoch = 1 + runner.message_hub = MessageHub.get_instance('test_after_train_epoch2') + + checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) + checkpoint_hook.before_train(runner) + checkpoint_hook.after_train_epoch(runner) + + runner.save_checkpoint.assert_called_once_with( + runner.work_dir, + 'epoch_2.pth', + None, + backend_args=None, + by_epoch=True, + save_optimizer=True, + save_param_scheduler=True) + def test_after_train_iter(self, tmp_path): work_dir = str(tmp_path) runner = Mock() diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 230355cc..3a3ddb37 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -26,6 +26,22 @@ class TestLoggerHook: with pytest.raises(ValueError): LoggerHook(file_client_args=dict(enable_mc=True)) + # test `file_client_args` and `backend_args` + with pytest.warns( + DeprecationWarning, + match='"file_client_args" will be deprecated in future'): + logger_hook = LoggerHook( + out_dir='tmp.txt', file_client_args={'backend': 'disk'}) + + with pytest.raises( + ValueError, + match='"file_client_args" and "backend_args" cannot be ' + 'set at the same time'): + logger_hook = LoggerHook( + out_dir='tmp.txt', + file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) + def test_before_run(self): runner = MagicMock() runner.iter = 10 -- GitLab
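For readers who want to see the API that the new tests above exercise in one place, the following is a minimal, self-contained usage sketch. It is not part of the patch; it assumes an environment with this refactored version of MMEngine installed, and the file names used are illustrative placeholders only.

    import tempfile
    from pathlib import Path

    import mmengine.fileio as fileio

    with tempfile.TemporaryDirectory() as tmp_dir:
        filepath = Path(tmp_dir) / 'demo.txt'

        # Module-level helpers dispatch to a backend inferred from the path,
        # here the LocalBackend.
        fileio.put_text('hello', filepath)
        assert fileio.get_text(filepath) == 'hello'
        assert set(fileio.list_dir_or_file(tmp_dir)) == {'demo.txt'}

        # The backend can also be obtained explicitly; with
        # enable_singleton=True the instance is cached and reused across
        # calls, mirroring the behaviour checked in test_get_file_backend.
        backend = fileio.get_file_backend(uri=filepath, enable_singleton=True)
        print(type(backend).__name__)  # LocalBackend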