Skip to content
Snippets Groups Projects
Unverified Commit a4f5533d authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

fix torch 1.10 amp error (#330)

parent 2b8a32ec
No related branches found
No related tags found
No related merge requests found
......@@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs):
'If pytorch versions is between 1.5.0 and 1.10, '
'`autocast` is only available in gpu mode')
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
digit_version('1.10.0')):
if torch.cuda.is_available():
kwargs.setdefault('device_type', 'cuda')
else:
kwargs.setdefault('device_type', 'cpu')
# torch.autocast only support `dtype=torch.bfloat16` in
# pytorch 1.10
kwargs.setdefault('dtype', torch.bfloat16)
with torch.autocast(enabled=enabled, **kwargs):
yield
elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if torch.cuda.is_available():
kwargs.setdefault('device_type', 'cuda')
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment