Skip to content
Snippets Groups Projects
Unverified Commit 12f7d3a0 authored by Jiazhen Wang's avatar Jiazhen Wang Committed by GitHub
Browse files

[Fix]: fix load_checkpoint (#332)

parent 2994195b
No related branches found
No related tags found
No related merge requests found
...@@ -1806,13 +1806,8 @@ class Runner: ...@@ -1806,13 +1806,8 @@ class Runner:
Defaults to 'default'. Defaults to 'default'.
""" """
if map_location == 'default': if map_location == 'default':
if torch.cuda.is_available(): device = get_device()
device = get_device() checkpoint = self.load_checkpoint(filename, map_location=device)
checkpoint = self.load_checkpoint(
filename,
map_location=lambda storage, loc: storage.to(device))
else:
checkpoint = self.load_checkpoint(filename)
else: else:
checkpoint = self.load_checkpoint( checkpoint = self.load_checkpoint(
filename, map_location=map_location) filename, map_location=map_location)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment