Skip to content
Snippets Groups Projects
Unverified Commit b5fd38ab authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance]: make total loss at the end of all losses (#369)

parent 8f91cf0b
No related branches found
No related tags found
No related merge requests found
...@@ -159,20 +159,23 @@ class BaseModel(BaseModule): ...@@ -159,20 +159,23 @@ class BaseModel(BaseModule):
all losses, and the second is log_vars which will be sent to the all losses, and the second is log_vars which will be sent to the
logger. logger.
""" """
log_vars = OrderedDict() log_vars = []
for loss_name, loss_value in losses.items(): for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor): if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean() log_vars.append([loss_name, loss_value.mean()])
elif is_list_of(loss_value, torch.Tensor): elif is_list_of(loss_value, torch.Tensor):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) log_vars.append(
[loss_name,
sum(_loss.mean() for _loss in loss_value)])
else: else:
raise TypeError( raise TypeError(
f'{loss_name} is not a tensor or list of tensors') f'{loss_name} is not a tensor or list of tensors')
loss = sum(value for key, value in log_vars.items() if 'loss' in key) loss = sum(value for key, value in log_vars if 'loss' in key)
log_vars['loss'] = loss log_vars.insert(0, ['loss', loss])
log_vars = OrderedDict(log_vars) # type: ignore
return loss, log_vars return loss, log_vars # type: ignore
def to(self, def to(self,
device: Optional[Union[int, str, torch.device]] = None, device: Optional[Union[int, str, torch.device]] = None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment