浏览 580

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple

def train(verbose = False):

net.train()
loss_list = []
for i,data in enumerate(train_dataloader):
    inputs = data['inputs']
    groundtruths = data['groundtruths']     
    if USE_GPU:
        inputs = Variable(inputs).cuda()
        groundtruths = Variable(groundtruths).cuda()
    else:
        inputs = Variable(inputs)
        groundtruths = Variable(groundtruths)
    #将参数的grad值初始化为0
    optimizer.zero_grad()
    #获得网络输出结果
    out = net(inputs)
    #根据真值计算损失函数的值
    loss = loss_criterion(out,groundtruths)
    #通过优化器优化网络
    loss.backward()
    optimizer.step()
    loss_list.append(loss.item())
return loss_list

def test():

error = 0.0
predictions = []
test_groundtruths = []
# 告诉网络进行测试,不再是训练模式
net.eval() 
for i,data in enumerate(test_dataloader):
    inputs = data['inputs']
    groundtruths = data['groundtruths']     
    if USE_GPU:
        inputs = Variable(inputs).cuda()
        groundtruths = Variable(groundtruths).cuda()
    else:
        inputs = Variable(inputs)
        groundtruths = Variable(groundtruths)
    out = net(inputs)
    error += (error_criterion(out,groundtruths).item()*groundtruths.size(0))
    if USE_GPU:
        predictions.extend(out.cpu().data.numpy().tolist())
        test_groundtruths.extend(groundtruths.cpu().data.numpy().tolist())
    else:
        predictions.extend(out.data.numpy().tolist())
        test_groundtruths.extend(groundtruths.data.numpy().tolist())
average_error = np.sqrt(error/len(test_data_trans))
return np.array(predictions).reshape((len(predictions))),np.array(test_groundtruths).reshape((len(test_groundtruths))),average_error

def main():

#记录程序开始的时间
train_start = time.time()
loss_recorder = []
print('starting training... ')
for epoch in range(EPOCHES):
    # adjust learning rate
    adjust_lr.step()
    loss_list = train(verbose= True)
    loss_recorder.append(np.mean(loss_list))
    print('epoch = %d,loss = %.5f'%(epoch+1,np.mean(loss_list)))
print ('training time = {}s'.format(int((time.time() - train_start))))
# 记录测试开始的时间
test_start = time.time()
predictions, test_groundtruth, average_error = test()
print(predictions.shape)
print(test_groundtruth.shape)
print('test time = {}s'.format(int((time.time() - test_start)+1.0)))
print('average error = ',  average_error)
result = pd.DataFrame(data = {'Q(t+1)':predictions,'Q(t+1)truth':test_groundtruth})
result.to_csv('D:/python目录/pythonProject/STA-LSTM-main/data/output/out_t+1.csv')
torch.save(net,'D:/python目录/pythonProject/STA-LSTM-main/models/sta_lstm_t+1.pth')

if name == 'main':
main()

0

  • 编辑 收藏 删除 结题
  • 追加酬金 (90%的用户在追加酬金后获得了解决方案)

    当前问题酬金

    ¥ 0 (可追加 ¥500)

    支付方式

    扫码支付

    加载中...

    提供问题酬金的用户不参与问题酬金结算和分配

    支付即为同意 《付费问题酬金结算规则》

    3 条回答 默认 最新

    • Dengwenjun1688 2023-02-16 17:47
      关注

      这个错误通常是因为输入参数类型错误引起的。在这个函数中,错误发生在`out = net(inputs

      你可以检查

      for i, data in enumerate(train_dataloader):
          inputs, groundtruths = data[0], data[1]
      

      这样,inputs和groundtruths就被提取出来并赋值给了相应的变量,就可以避免这个错误。

      本回答被题主选为最佳回答 , 对您是否有帮助呢? 本回答被专家选为最佳回答 , 对您是否有帮助呢? 本回答被题主和专家选为最佳回答 , 对您是否有帮助呢?
      按下Enter换行,Ctrl+Enter发表内容
    查看更多回答(2条)

    报告相同问题?

    问题事件

    • 创建了问题 2月16日

    悬赏问题

    • ¥300 Node.js connect ECONNREFUSED 错误 用Python 编译KEITHLEY 2400遇到的问题。 曲面部件制孔路径优化 华为OLT告警删除及屏蔽 微信小程序如何提高加载速度 用python设计一个滑动拼图小游戏 求2方主体各3策略演化博弈模型matlab代码 请教下vscode 下 lua 用哪个字体可以连接符号? ~= ==啥的