我正在编写一个GRU,当我试图进行预测时,我会得到一个错误,表明我需要为forward()定义h。在搜索和搜索堆栈溢出数小时后,我尝试了几次,并失去了耐心。
这是一堂课:
class GRUNet(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob = 0.2): super(GRUNet, self).__init__() self.hidden_dim = hidden_dim self.n_layers = n_layers self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob) self.fc = nn.Linear(hidden_dim, output_dim) self.relu = nn.ReLU() def forward(self, x, h): out, h = self.gru(x,h) out = self.fc(self.relu(out[:,-1])) return out, h def init_hidden(self, batch_size): weight = next(self.parameters()).data hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device) return hidden
然后,我在这里加载模型,并尝试进行预测。这两个都在同一个脚本中。
inputs = np.load('.//Pred//input_list.npy') print(inputs.ndim, inputs.shape) Gmodel = GRUNet(24,256,1,2) Gmodel = torch.load('.//GRU//GRU_1028_48.pkl') Gmodel.eval() pred = Gmodel(inputs)
在不使用任何其他参数的情况下,我得到了以下内容:
Traceback (most recent call last): File ".\grunet.py", line 136, in <module>