# Project scenario:
PyTorch `torch.load` raises ModuleNotFoundError: No module named 'models' when loading a YOLOv5 checkpoint.
# Problem description:
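I copied a model trained on the server to my local machine to inspect it.
Directory structure:
```
----infer
    ---- current_file.py (the script being run)
----yolov5
```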
```python
import torch
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='D:/inference_model/tangbao-citie/2021-12-31/best.pt', help='weights path')
    opt = parser.parse_args()
    # Load the PyTorch model: the YOLOv5 checkpoint is a dict whose 'model' entry is the pickled model object
    print(opt.weights)
    model = torch.load(opt.weights, map_location=torch.device('cpu'))['model']
    # Print the dtype of every parameter
    for name, parameters in model.named_parameters():
        print(parameters.dtype)
```
Error:
```
D:\install\anconda\envs\pytorch-test\python.exe D:/pytorch-work/pytorch_infer/cat_model.py
Traceback (most recent call last):
  File "D:\pytorch-work\pytorch_infer\cat_model.py", line 23, in <module>
    model = torch.load(opt.weights, map_location=torch.device('cpu'))['model']
  File "D:\install\anconda\envs\pytorch-test\lib\site-packages\torch\serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "D:\install\anconda\envs\pytorch-test\lib\site-packages\torch\serialization.py", line 882, in _load
    result = unpickler.load()
  File "D:\install\anconda\envs\pytorch-test\lib\site-packages\torch\serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'models'
```
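One quick way to confirm that this is an import-path problem rather than a damaged checkpoint is to ask Python where it would import `models` from before calling `torch.load`. When a script is run directly, Python puts the script's own directory at `sys.path[0]`, so what counts is which directories are on `sys.path`, not where the `.pt` file lives. The sketch below uses `importlib.util.find_spec`; the helper name `describe_module` is made up for illustration.

```python
import importlib.util
import sys


def describe_module(name: str) -> str:
    """Hypothetical helper: report where a top-level module would be imported from."""
    if name in sys.modules:
        module = sys.modules[name]
        return f"{name}: already imported from {getattr(module, '__file__', None)}"
    spec = importlib.util.find_spec(name)
    if spec is None:
        return f"{name}: not importable with the current sys.path"
    # Namespace packages have no single origin file, so fall back to their search locations
    location = spec.origin or list(spec.submodule_search_locations or [])
    return f"{name}: would be imported from {location}"


print(sys.path[0])  # the running script's own directory when a script is run directly
print(describe_module("models"))
```

In the failing setup above, `describe_module("models")` would be expected to report that `models` is not importable.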
# Cause analysis:
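The checkpoint stores the entire YOLOv5 model object, and unpickling it re-imports the classes it was built from by their module path (classes under YOLOv5's top-level `models` package, e.g. `models.yolo` and `models.common`). That package lives in the yolov5 directory, so the script must either sit inside the yolov5 directory or have that directory on `sys.path`; otherwise Python cannot find `models`.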
https://github.com/ultralytics/yolov5/issues/353
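The failure comes from how `pickle` (which `torch.load` uses for unpickling, as the `find_class` frame in the traceback shows) serializes objects: an instance is stored as a reference to its class by module name and class name plus the instance's attributes, not the class's code. The self-contained sketch below reproduces the same error with the standard `pickle` module; the in-memory `models` module and its `Model` class are hypothetical stand-ins for YOLOv5's package.

```python
import pickle
import sys
import types


class Model:
    """Hypothetical stand-in for a YOLOv5 model class."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Pretend the class lives in a top-level module called "models"
fake_models = types.ModuleType("models")
Model.__module__ = "models"
fake_models.Model = Model
sys.modules["models"] = fake_models

# pickle records only the reference "models" + "Model" and the instance's __dict__
payload = pickle.dumps(Model())

# Simulate running where "models" cannot be imported (assumes no real "models" on sys.path)
del sys.modules["models"]
try:
    pickle.loads(payload)
    error_message = None
except ModuleNotFoundError as exc:
    error_message = str(exc)
print(error_message)  # No module named 'models'

# Make "models" importable again (the sys.path fix does this for the real package)
sys.modules["models"] = fake_models
restored = pickle.loads(payload)
print(type(restored).__module__, restored.weights)  # models [0.1, 0.2]
```

This is why copying only `best.pt` to another machine is not enough: the code that defines the pickled classes has to be importable there too.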
# Solution:
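Put the yolov5 directory at the front of `sys.path` before calling `torch.load`, so that `models` can be imported during unpickling:
```python
import sys

# Make YOLOv5's top-level `models` package importable
sys.path.insert(0, 'D:/pytorch-work/yolov5')
print(sys.path)  # check that the yolov5 path is now first
```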
Full script:
```python
import argparse
import sys

import torch

# Make YOLOv5's `models` package importable; this only has to happen before torch.load runs
sys.path.insert(0, 'D:/pytorch-work/yolov5')
print(sys.path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='D:/inference_model/tangbao-citie/2021-12-31/best.pt', help='weights path')
    opt = parser.parse_args()
    # Load the PyTorch model: the YOLOv5 checkpoint is a dict whose 'model' entry is the pickled model object
    print(opt.weights)
    model = torch.load(opt.weights, map_location=torch.device('cpu'))['model']
    # Print the dtype of every parameter
    for name, parameters in model.named_parameters():
        print(parameters.dtype)
```
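The hard-coded `D:/pytorch-work/yolov5` only works on one machine. Assuming the layout from the problem description (the script's folder and `yolov5` are siblings under one root), a sketch that derives the path from the script's own location might look like this; `add_yolov5_to_path` is a hypothetical helper, not part of YOLOv5 or PyTorch.

```python
import sys
from pathlib import Path


def add_yolov5_to_path(script_file: str, repo_name: str = "yolov5") -> Path:
    """Hypothetical helper: put the sibling YOLOv5 checkout at the front of sys.path.

    Assumes <root>/<script folder>/<script>.py and <root>/yolov5/models/.
    The directory is inserted only if it is not already on sys.path.
    """
    repo_dir = Path(script_file).resolve().parent.parent / repo_name
    if not (repo_dir / "models").is_dir():
        raise FileNotFoundError(f"no 'models' package under {repo_dir}")
    repo_str = str(repo_dir)
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_dir
```

In the script, calling `add_yolov5_to_path(__file__)` before `torch.load(...)` would replace the two hard-coded `sys.path` lines; the `FileNotFoundError` makes a wrong layout fail early with a clearer message than the unpickling error.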