diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index fec759bef55c0094f62cde697a38aacd308bd2d4..f754e88c8f98f2407031849f930aa0557c1cef49 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1806,13 +1806,8 @@ class Runner: Defaults to 'default'. """ if map_location == 'default': - if torch.cuda.is_available(): - device = get_device() - checkpoint = self.load_checkpoint( - filename, - map_location=lambda storage, loc: storage.to(device)) - else: - checkpoint = self.load_checkpoint(filename) + device = get_device() + checkpoint = self.load_checkpoint(filename, map_location=device) else: checkpoint = self.load_checkpoint( filename, map_location=map_location)