File "/home/guodong/github/guodong/Bert-Chinese-Text-Classification-Pytorch/to_onnx.py", line 64, in <module>
torch.onnx.export(model,
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/onnx/utils.py", line 504, in export
_export(
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/onnx/utils.py", line 1529, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/onnx/utils.py", line 1111, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/onnx/utils.py", line 987, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/onnx/utils.py", line 891, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/jit/_trace.py", line 1184, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/guodong/miniconda3/envs/onnx_gpu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1178, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: Model.forward() takes 2 positional arguments but 4 were given
The error says forward() takes two positional arguments (self plus one input), but four were given. The cause: torch.onnx.export unpacks the args tuple and passes each element as a separate positional argument to forward(), while this model's forward() expects all of its inputs packed into a single tuple x.
Take a look at the model's forward() definition (in models/bert.py):
class Model(nn.Module):
    ... ...
    def forward(self, x):
        context = x[0]  # the input sentence (token ids)
        mask = x[2]  # masks the padding; same size as the sentence, with 0 marking padding, e.g. [1, 1, 1, 1, 0, 0]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out
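So forward() wants one tuple, but the export call handed it three separate tensors. The fix is to wrap the inputs in one more tuple, so the tracer passes them through as a single argument. A minimal sketch of the corrected call, assuming ids, seq_len, and mask have already been built as tensors; the opset version is an assumption:

torch.onnx.export(
    model,
    ((ids, seq_len, mask),),      # 1-tuple: forward() receives the inner tuple as x
    'model.onnx',
    input_names=['ids', 'mask'],  # x[1] (seq_len) is never read in forward(), so the
                                  # exported graph keeps only the ids and mask inputs
    output_names=['output'],
    opset_version=11,             # assumption; any recent opset should work
)

This also explains why the inference script below feeds only 'ids' and 'mask'.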
#!/usr/bin/env python
# coding=utf-8
import numpy as np
import onnxruntime as ort
import pred
def predict(sess, text):
ids, seq_len, mask = pred.build_predict_text_raw(text)
    inputs = {
        'ids': np.array([ids]),
        'mask': np.array([mask]),
    }
    outs = sess.run(None, inputs)  # None -> return every model output
    num = np.argmax(outs)          # index of the highest-scoring class
    return pred.key[num]
if __name__ == '__main__':
sess = ort.InferenceSession('./model.onnx', providers=['CUDAExecutionProvider'])
t = '天问一号着陆火星一周年'
res = predict(sess, t)
print('%s is %s' % (t, res))
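The keys of the feed dict must match the input names recorded in the ONNX graph at export time. If in doubt, they can be listed through ONNX Runtime's standard API; this quick check is not part of the original script:

for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)  # each graph input's name, shape, and dtype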
Final output:
天问一号着陆火星一周年 is science
Performance comparison
How does the model perform after it has been converted to ONNX? Let's measure it. First, a couple of helper functions:
import time  # used by batch_predict below

def load_title(fname):
    """Load news titles from a file, one per line."""
    ts = []
    with open(fname) as f:
        for line in f:
            ts.append(line.strip())
    return ts
def batch_predict(ts, predict_fun, name):
    """Run predict_fun over every title, printing results and the total elapsed time."""
    print('')
    a = time.time()
    for t in ts:
        res = predict_fun(t)
        print('%s is %s' % (t, res))
    b = time.time()
    print('%s cost: %.4f' % (name, (b - a)))
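A driver along these lines produces the output below. It is a sketch: the titles file name, the functools.partial wiring, and a pred.predict function for the original PyTorch model are assumptions, since the actual benchmark script is not shown here:

from functools import partial

if __name__ == '__main__':
    sess = ort.InferenceSession('./model.onnx', providers=['CUDAExecutionProvider'])
    titles = load_title('titles.txt')  # hypothetical file: one news title per line

    # ONNX Runtime on CUDA
    batch_predict(titles, partial(predict, sess), 'ONNX_CUDA')
    # original PyTorch model, for comparison (pred.predict is assumed to exist)
    batch_predict(titles, pred.predict, 'PyTorch')

Output: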
杭州购房政策大松绑 is realty
兰州野生动物园观光车侧翻事故新进展:2人经抢救无效死亡 is society
4个小学生离家出走30公里想去广州塔 is society
朱一龙戏路打通电影电视剧 is entertainment
天问一号着陆火星一周年 is science
网友拍到天舟五号穿云而出 is science
ONNX_CUDA cost: 0.0406
杭州购房政策大松绑 is realty
兰州野生动物园观光车侧翻事故新进展:2人经抢救无效死亡 is society
4个小学生离家出走30公里想去广州塔 is society
朱一龙戏路打通电影电视剧 is entertainment