Unverified commit b5fd38ab authored by Mashiro, committed by GitHub

[Enhance]: make total loss at the end of all losses (#369)

parent 8f91cf0b
@@ -159,20 +159,23 @@ class BaseModel(BaseModule):
             all losses, and the second is log_vars which will be sent to the
             logger.
         """
-        log_vars = OrderedDict()
+        log_vars = []
         for loss_name, loss_value in losses.items():
             if isinstance(loss_value, torch.Tensor):
-                log_vars[loss_name] = loss_value.mean()
+                log_vars.append([loss_name, loss_value.mean()])
             elif is_list_of(loss_value, torch.Tensor):
-                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+                log_vars.append(
+                    [loss_name,
+                     sum(_loss.mean() for _loss in loss_value)])
             else:
                 raise TypeError(
                     f'{loss_name} is not a tensor or list of tensors')
 
-        loss = sum(value for key, value in log_vars.items() if 'loss' in key)
-        log_vars['loss'] = loss
+        loss = sum(value for key, value in log_vars if 'loss' in key)
+        log_vars.insert(0, ['loss', loss])
+        log_vars = OrderedDict(log_vars)  # type: ignore
 
-        return loss, log_vars
+        return loss, log_vars  # type: ignore
 
     def to(self,
            device: Optional[Union[int, str, torch.device]] = None,
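
For context, the sketch below reproduces the behaviour of parse_losses after this change on a toy losses dict: per-loss values (or lists of values) are reduced with mean(), every entry whose key contains 'loss' is summed into the total, and the total is inserted at index 0 of log_vars before conversion to OrderedDict. The function name parse_losses_sketch, its type hints, and the isinstance-based stand-in for is_list_of are illustrative assumptions, not the project's actual code.

    # Minimal sketch, not the project's method; mirrors the updated logic above.
    from collections import OrderedDict
    from typing import Dict, List, Tuple, Union

    import torch


    def parse_losses_sketch(
        losses: Dict[str, Union[torch.Tensor, List[torch.Tensor]]]
    ) -> Tuple[torch.Tensor, OrderedDict]:
        log_vars = []
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars.append([loss_name, loss_value.mean()])
            elif isinstance(loss_value, list) and all(
                    isinstance(v, torch.Tensor) for v in loss_value):
                # stand-in for ``is_list_of(loss_value, torch.Tensor)``
                log_vars.append(
                    [loss_name, sum(v.mean() for v in loss_value)])
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose key contains 'loss' contribute to the total,
        # which is then placed at index 0 of log_vars.
        loss = sum(value for key, value in log_vars if 'loss' in key)
        log_vars.insert(0, ['loss', loss])
        return loss, OrderedDict(log_vars)


    losses = {'loss_cls': torch.tensor(0.5),
              'loss_bbox': torch.tensor(0.3),
              'acc': torch.tensor(0.9)}
    total, log_vars = parse_losses_sketch(losses)
    # total ≈ 0.8; key order: 'loss', 'loss_cls', 'loss_bbox', 'acc'
    print(total, list(log_vars))

Note that 'acc' is logged but excluded from the total because its key does not contain 'loss', which is exactly the filtering rule visible in the diff.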