Skip to content
Snippets Groups Projects
Unverified Commit 2df5bc13 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix base tta model (#593)


Co-authored-by: default avatarubuntu <ubuntu@localhost.localdomain>
parent 46add351
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement] ...@@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement]
@MODELS.register_module() @MODELS.register_module()
class BaseTTAModel: class BaseTTAModel(nn.Module):
"""Base model for inference with test-time augmentation. """Base model for inference with test-time augmentation.
``BaseTTAModel`` is a wrapper for inference given multi-batch data. ``BaseTTAModel`` is a wrapper for inference given multi-batch data.
...@@ -74,6 +74,7 @@ class BaseTTAModel: ...@@ -74,6 +74,7 @@ class BaseTTAModel:
""" """
def __init__(self, module: Union[dict, nn.Module]): def __init__(self, module: Union[dict, nn.Module]):
super().__init__()
if isinstance(module, nn.Module): if isinstance(module, nn.Module):
self.module = module self.module = module
elif isinstance(module, dict): elif isinstance(module, dict):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment