相关文章推荐
谦和的猕猴桃  ·  商品列表筛选 ...·  1 年前    · 
暗恋学妹的煎饼  ·  PostgreS ...·  1 年前    · 
豪情万千的松鼠  ·  【Java ...·  1 年前    · 
Python 3.11.3 (main, Apr  7 2023, 19:25:52) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.

正常情况下,我们会使用 torch.save 保存模型的 state_dict 。但我们也可以 torch.save 保存一个自定义类型对象,例如

import torch
import torch.nn as nn
class Module(nn.Module):
    def __init__(self) -> None:
        self._one = 1
torch.save(Module(), 'module.pth')

在读取 module.pth 时,可能会遇到 AttributeError

import torch
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 3, in <module>
    torch.load('module.pth')
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'Module' on <module '__main__' from '/Users/bytedance/Developer/todd/load.py'>

这是因为 torch.save 底层通过 pickle 实现,而 pickle 在保存自定义类型对象时不会保存其类型定义。用户需要保证 torch.load 时,自定义类型可访问,以便构造被保存的对象。也就是说,如果我们将 Module 引用到当前命名空间,就可以正常加载 module.pth

import torch
from save import Module
torch.load('module.pth')

但是有些情况下,我们无法访问某些自定义类型,也不希望恢复被保存的对象,只想知道被保存的对象存储了哪些数据,可以用下面的方法

import torch
class Module:
    def __init__(self) -> None:
        # in case __setstate__ is not called
        self._state = None
    def __setstate__(self, state):
        # whenever state is not empty, __setstate__ is called
        self._state = state
module = torch.load('module.pth')
print(module._state)
{'_one': 1}

但是如果自定义类型是从其他位置 import 得到的,例如

# module.py
import torch.nn as nn
class Module(nn.Module):
    def __init__(self) -> None:
        self._one = 1
# save.py
import torch
from module import Module
torch.save(Module(), 'module.pth')

torch.load 会先尝试 import 相应的模块,如果不存在就会报错

Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 13, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'module'

我们可以 mock 相应模块

import sys
from unittest.mock import Mock
import torch
sys.modules['module'] = Mock()
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 14, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: NEWOBJ class argument must be a type, not Mock

出现这个问题,是因为 Mock 具有递归创建的特性。我们可以手动修改

import sys
from unittest.mock import Mock
import torch
class Module:
    def __init__(self) -> None:
        self._state = None
    def __setstate__(self, state):
        self._state = state
sys.modules['module'] = Mock()
sys.modules['module'].Module = Module
module = torch.load('module.pth')
print(module._state)
{'_one': 1}
                    但是有些情况下,我们无法访问某些自定义类型,也不希望恢复被保存的对象,只想知道被保存的对象存储了哪些数据,可以用下面的方法。时,自定义类型可访问,以便构造被保存的对象。也就是说,如果我们将。在保存自定义类型对象时不会保存其类型定义。时,可能会遇到 AttributeError。引用到当前命名空间,就可以正常加载。具有递归创建的特性。但是如果自定义类型是从其他位置。相应的模块,如果不存在就会报错。保存一个自定义类型对象,例如。正常情况下,我们会使用。出现这个问题,是因为。
				
对于如何自定义评估方法,也就是重写Metrics的子类。 官方文档和TorchMetrics:PyTorch的指标度量库都写的比较完备了。 针对第二篇文章中给出的样例,在实现后得到object has no attribute '_defaults’报错。原因在于没有初始化父类。 加上super().__init__()就可以了
cuDNN使用非确定性算法,并且可以使用torch.backends.cudnn.enabled = False来进行禁用 如果设置为torch.backends.cudnn.enabled =True,说明设置为使用使用非确定性算法 然后再设置: torch.backends.cudnn.benchmark = true 那么cuDNN使用的非确定性算法就会自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题 一般来讲,应该遵循以下准则: 如果网络的输入数据维度或类型上变化不大,设置  torch.backends.cudnn.benchmark = true  可以增加运行效率;
AttributeError: ‘...’ object has no attribute ‘copy’ AttributeError: ‘...’ object has no attribute ‘module’ 引发的原因可能是: 使用了torch.save(model)保存模型参数,但使用model.torch.load()...
torch.load()加载模型时出现如下错误 Traceback (most recent call last): File "demo_syncnet.py", line 26, in <module> s.loadParameters(opt.initial_model); File "/media/cj/75bb371d-0b6d-4995-bb72-060a4...
*** FileNotFoundError: [Errno 2] No such file or directory: '~/DeepLearning/.../XXXXXnet.pth' 对应的代码为: weight_path = '~/DeepLearning/......' save_path = os.path.join(weight_path, save_filename) 错误原因: 这里是python不是linux,~在Python里不能直接代表主目录,所以要用绝对路径。
torch.load 出现 AttributeError: Can't get attribute 'Net' on module '__main__'问题解决方案 最近,将已经训练好的模型保存下来后,通过torch.load(model_path)方法读取时,发现没办法正常运行,抛出如下错误: AttributeError: Can't get attribute 'Net' on module '__main__' 我直接好家伙,骂骂咧咧去搜为啥。 报错原因: torch.load()方法所
我这里是升级torchvision后问题解决了 pip install torch==0.4.1 torchvision==0.4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple/ python = 3.6 torch = 0.4.1 torchvision = 0.4.0 今天碰到一个怪问题,明明各种包都已经安装好了,进入python也可以正常使用pytorch,但一进入ipython, jupyter notebook就无法使用pytorch, &gt;&gt;&gt;import torch as t ModuleNotFoundError: No module named 'torch' 事发突然,不知何故,硬着头皮重新安装 $ co...
根据引用[1]和引用[2]的内容,你遇到的报错可能是因为你尝试使用torch.load()加载一个不是由torch.save()保存的对象。torch.load()是用来加载由torch.save()存储的对象的方法。它使用Python的unpickling工具来处理存储的对象。如果你尝试加载一个不是由torch.save()保存的对象,就会引发异常。 为了解决这个问题,你可以尝试以下方法: 1. 确保你使用torch.save()正确保存了对象。你可以使用torch.save(model, 'save.pt')来保存整个模型,或者使用torch.save(model.state_dict(), 'save.pt')来保存训练好的权重。 2. 确保你使用torch.load()加载的是由torch.save()保存的对象。你可以使用torch.load('save.pt')来加载整个模型,或者使用model.load_state_dict(torch.load("save.pt"))来加载训练好的权重。 希望这些方法能够帮助你解决torch.load报错的问题。如果问题仍然存在,你可以尝试在错误处向前溯源打断点,并逐步进行调试。
LutingWang: 不一定,报错中的 _thread.RLock 只是表示不可序列化的对象是一个 _thread.RLock 类型的对象。如果报错信息中的类型不是 _thread.RLock,那么就需要递归查找对应类型的属性。比如我们用 pickle.dump 去存储一个 open 返回的文件句柄时,就会出现这样的报错信息 TypeError: cannot pickle '_io.TextIOWrapper' object