1. Training crashes mid-run
Loss: 0.4968: 88%|████████████████████▎ | 17173/19457 [44:41<05:56, 6.40it/s]
Traceback (most recent call last):
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 97, in <module>
    train(args)
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 53, in train
    loss.backward()
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: transform: failed to synchronize: cudaErrorAssert: device-side assert triggered
Cause: the token count of some labels exceeds the configured maximum sequence length.
Fix: compute the maximum token count over the dataset and raise the setting accordingly.
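A minimal sketch of that check, assuming the repo's BPE tokenizer JSON and a plain-text file of formulas (both file names below are hypothetical): encode every label, report the longest one, and set the max-sequence-length parameter above that value.

from tokenizers import Tokenizer

# Hypothetical paths -- point these at your own tokenizer and label file.
tok = Tokenizer.from_file("dataset/tokenizer.json")
with open("dataset/math.txt", encoding="utf-8") as f:
    lengths = [len(tok.encode(line.strip()).ids) for line in f if line.strip()]
print("max token count:", max(lengths))  # set the max sequence length above this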
2. Training won't run at all
Traceback (most recent call last):
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 100, in <module>
    train(args)
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 54, in train
    encoded = encoder(im.to(device))
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/appuser/.local/lib/python3.6/site-packages/timm/models/vision_transformer.py", line 374, in forward
    x = self.forward_features(x)
  File "/media/newData/user/dxq/LaTeX-OCR/models.py", line 80, in forward_features
    x += self.pos_embed[:, pos_emb_ind]
RuntimeError: The size of tensor a (25) must match the size of tensor b (12) at non-singleton dimension 1
The offending code is here:
def forward_features(self, x):
    B, c, h_init, w_init = x.shape
    x = self.patch_embed(x)
    cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    total_channel = x.shape[1]
    # Number of patch rows/columns for this image.
    h, w = math.ceil(h_init / self.patch_size), math.ceil(w_init / self.patch_size)
    # if h*w+1 != total_channel:
    #     h, w = w_init // self.patch_size, w_init // self.patch_size
    # Pick the positional embeddings of the top-left h x w block
    # out of the full (self.width // self.patch_size)-wide grid.
    pos_emb_ind = repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)
    pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()  # index 0 is the cls token
    pos_embed = self.pos_embed
    x += self.pos_embed[:, pos_emb_ind]
    # x = x + self.pos_embed
    x = self.pos_drop(x)
    # torch.cat((torch.zeros(1), repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w) + 1), dim=0).long()
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x
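To make the index arithmetic concrete, here is a tiny standalone example (the grid width of 8 and the 2x3 patch layout are made-up numbers) showing which entries of pos_embed get selected:

import torch
from einops import repeat

# Made-up numbers: full positional grid is 8 patches wide, current image is 2 rows x 3 cols of patches.
grid_w, h, w = 8, 2, 3
pos_emb_ind = repeat(torch.arange(h) * (grid_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind)  # tensor([ 0,  1,  2,  3,  9, 10, 11])
# i.e. the cls slot plus the top-left 2x3 block of the 8-wide grid (shifted by 1 for the cls token).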
The dimensions don't match because the original line derives the patch grid with floor division (and uses w_init for both h and w), so h*w no longer equals the number of patch tokens. Rounding up with math.ceil fixes it:

# before
h, w = w_init // self.patch_size, w_init // self.patch_size
# after
h, w = math.ceil(h_init / self.patch_size), math.ceil(w_init / self.patch_size)
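While debugging, a cheap guard adapted from the commented-out check in the snippet above (my own addition, not part of the repo) makes the failure point at the real cause instead of a cryptic size mismatch:

# inside forward_features, right after h and w are computed
assert h * w + 1 == x.shape[1], \
    f"patch grid {h}x{w} (+cls) gives {h * w + 1} tokens but x has {x.shape[1]} -- check patch_size / padding"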
Exhausting. So much frantic debugging.