想请教各位一个问题:看官方 API 文档说 summary() 中的参数是可选的,可是我在学习时想使用 summary() 打印网络的基础结构和参数信息看看,结果报如下错误:
---------------------------------------------------------------------------ValueError Traceback (most recent call last) in
----> 1 model_cnn.summary()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in summary(self, input_size, dtype)
1879 else:
1880 _input_size = self._inputs
-> 1881 return summary(self.network, _input_size, dtype)
1883 def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary(net, input_size, dtypes)
148 _input_size = _check_input(_input_size)
--> 149 result, params_info = summary_string(net, _input_size, dtypes)
150 print(result)
in summary_string(model, input_size, dtypes)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py in _decorate_function(func, *args, **kwargs)
313 def _decorate_function(func, *args, **kwargs):
314 with self:
--> 315 return func(*args, **kwargs)
317 @decorator.decorator
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary_string(model, input_size, dtypes)
275 # make a forward pass
--> 276 model(*x)
278 # remove these hooks
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
900 self._built = True
--> 902 outputs = self.forward(*inputs, **kwargs)
904 for forward_post_hook in self._forward_post_hooks.values():
in forward(self, text, seq_len)
32 def forward(self, text, seq_len=None):
33 # Shape: (batch_size, num_tokens, embedding_dim)
---> 34 embedded_text = self.embedder(text)
35 print('after word-embeding:', embedded_text.shape)
36 # Shape: (batch_size, len(ngram_filter_sizes)*num_filter)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
900 self._built = True
--> 902 outputs = self.forward(*inputs, **kwargs)
904 for forward_post_hook in self._forward_post_hooks.values():
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/common.py in forward(self, x)
1361 padding_idx=self._padding_idx,
1362 sparse=self._sparse,
-> 1363 name=self._name)
1365 def extra_repr(self):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/functional/input.py in embedding(x, weight, padding_idx, sparse, name)
198 return core.ops.lookup_table_v2(
199 weight, x, 'is_sparse', sparse, 'is_distributed', False,
--> 200 'remote_prefetch', False, 'padding_idx', padding_idx)
201 else:
202 helper = LayerHelper('embedding', **locals())
ValueError: (InvalidArgument) Tensor holds the wrong type, it holds float, but desires to be int64_t.
[Hint: Expected valid == true, but received valid:0 != true:1.] (at /paddle/paddle/fluid/framework/tensor_impl.h:33)
[operator < lookup_table_v2 > error]
源码如下:
class CNNModel(nn.Layer):
    """Text-classification network: Embedding -> CNN encoder -> 2 FC layers.

    NOTE(review): the first layer is an ``nn.Embedding``, so ``forward``
    expects *int64 token ids*, not floats. A dummy float input (e.g. from
    ``Model.summary()`` called without an explicit dtype) triggers the
    ``lookup_table_v2`` "desires to be int64_t" error seen above.
    """

    def __init__(self,
                 vocab_size,
                 num_classes,
                 emb_dim=128,
                 padding_idx=0,
                 num_filter=128,
                 ngram_filter_sizes=(3,),
                 fc_hidden_size=96):
        super().__init__()
        # Token-id -> dense vector lookup; row `padding_idx` is the pad slot.
        self.embedder = nn.Embedding(
            vocab_size, emb_dim, padding_idx=padding_idx)
        # CNN sequence encoder that pools variable-length token sequences
        # into a fixed-size vector of len(ngram_filter_sizes) * num_filter.
        self.encoder = ppnlp.seq2vec.CNNEncoder(
            emb_dim=emb_dim,
            num_filter=num_filter,
            ngram_filter_sizes=ngram_filter_sizes)
        self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size)
        self.output_layer = nn.Linear(fc_hidden_size, num_classes)

    def forward(self, text, seq_len=None):
        """Return classification logits.

        Args:
            text: int64 tensor of shape (batch_size, num_tokens).
            seq_len: unused by this model; kept so callers that pass
                sequence lengths (e.g. padded-batch pipelines) still work.

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, len(ngram_filter_sizes) * num_filter)
        # (debug print of the embedded shape removed)
        encoder_out = paddle.tanh(self.encoder(embedded_text))
        # Shape: (batch_size, fc_hidden_size)
        fc_out = self.fc(encoder_out)
        # Shape: (batch_size, num_classes)
        logits = self.output_layer(fc_out)
        return logits
# Build the network and wrap it with the high-level Model API.
model_cnn = CNNModel(27665, 2)
model_cnn = paddle.Model(model_cnn)
# Calling summary() with no arguments makes Paddle fabricate a float32
# dummy input, but the first layer (nn.Embedding / lookup_table_v2)
# requires int64 token ids — that mismatch is exactly the
# "Tensor holds the wrong type ... desires to be int64_t" error above.
# Fix: state the input shape and dtype explicitly. The sequence length
# (128 here) is arbitrary for summary purposes; the batch dim is 1.
model_cnn.summary(input_size=(1, 128), dtype='int64')
通过查看 GitHub 上的 issues,找到了解决方法,是自己学艺不精。。。哎。。。。。这样就行了:
但是我不知道这个 input_size 对不对 QAQ
这个打印出来是这样的: