Keras many-to-many classification LSTM error: Can not squeeze dim[2], expected a dimension of 1, got 3 [[node categorical_crossentropy/]

I want to build an LSTM model that computes the probabilities of three classes at every time step.

I define the model as follows:

from tensorflow import keras
from tensorflow.keras import layers

# input_size, output_size and Generator are defined elsewhere (not shown).
model = keras.Sequential()
model.add(layers.LSTM(128, input_shape=(None, input_size)))
model.add(layers.Dense(output_size))
model.add(layers.Softmax())
model.compile(optimizer="sgd", loss=keras.losses.CategoricalCrossentropy())
model.fit(Generator())

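The generator yields batches consisting of a single training example (I plan to use batches of several examples eventually, but because my training examples have different lengths I am keeping things as simple as possible for now). More precisely, the generator returns x of shape (1, T, input_size) and y of shape (1, T, 3). After processing a few batches, Keras fails with the following error: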

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-51-ad3e71307f94> in <module>
      1 model.compile(optimizer="sgd",loss=keras.losses.CategoricalCrossentropy())
----> 2 model.fit(Generator())
c:\theproject\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    110     # Running inside `run_distribute_coordinator` already.
c:\theproject\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096                 batch_size=batch_size):
   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)
   1099               if data_handler.should_sync:
   1100                 context.async_wait()
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    782       new_tracing_count = self._get_tracing_count()
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    838         # Lifting succeeded, so variables are initialized and we can run the
    839         # stateless function.
--> 840         return self._stateless_fn(*args, **kwds)
    841     else:
    842       canon_args, canon_kwds = \
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2827     with self._lock:
   2828       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2831   @property
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1841       `args` and `kwargs`.
   1842     """
-> 1843     return self._call_flat(
   1844         [t for t in nest.flatten((args, kwargs), expand_composites=True)
   1845          if isinstance(t, (ops.Tensor,
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1921         and executing_eagerly):
   1922       # No tape is watching; skip to running the function.
-> 1923       return self._build_call_outputs(self._inference_function.call(
   1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    543       with _InterpolateFunctionError(self):
    544         if cancellation_manager is None:
--> 545           outputs = execute.execute(
    546               str(self.signature.name),
    547               num_outputs=self._num_outputs,
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
InvalidArgumentError:  Can not squeeze dim[2], expected a dimension of 1, got 3
     [[node categorical_crossentropy/remove_squeezable_dimensions/Squeeze (defined at <ipython-input-51-ad3e71307f94>:2) ]] [Op:__inference_train_function_16213]
Function call stack:
train_function
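
One plausible reading of this error, offered as a sketch rather than a confirmed diagnosis: `layers.LSTM` defaults to `return_sequences=False`, so the model as posted emits a single prediction per sequence, of shape (1, output_size), while the targets have shape (1, T, 3). The loss then appears to try squeezing the last axis of the targets to reconcile the ranks, which fails because that axis has size 3. The pure-Python helpers below are hypothetical names, not part of Keras, and only do the shape bookkeeping. They assume output_size = 3, with T = 989 and input_size = 179 taken from the sample in this post.

```python
def lstm_output_shape(input_shape, units, return_sequences=False):
    """Shape produced by a Keras-style LSTM layer (bookkeeping only)."""
    batch, timesteps, _features = input_shape
    if return_sequences:
        # One hidden vector per time step.
        return (batch, timesteps, units)
    # Default: only the hidden vector of the last time step.
    return (batch, units)


def dense_output_shape(input_shape, units):
    """Dense acts on the last axis and leaves the leading axes alone."""
    return tuple(input_shape[:-1]) + (units,)


def model_output_shape(input_shape, return_sequences):
    """LSTM(128) -> Dense(3) -> Softmax; Softmax does not change the shape."""
    hidden = lstm_output_shape(input_shape, 128, return_sequences)
    return dense_output_shape(hidden, 3)


x_shape = (1, 989, 179)  # (batch, T, input_size), assumed from the sample
y_shape = (1, 989, 3)    # (batch, T, number of classes)

as_posted = model_output_shape(x_shape, return_sequences=False)
per_step = model_output_shape(x_shape, return_sequences=True)

print(as_posted, as_posted == y_shape)  # (1, 3) False
print(per_step, per_step == y_shape)    # (1, 989, 3) True
```

If this reading is right, the corresponding change in the model would be `layers.LSTM(128, return_sequences=True, input_shape=(None, input_size))`, leaving the `Dense` and `Softmax` layers as they are.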

Here are the X and Y (X first, then Y) of one sample on which it fails. It is not the only such sample.

          0    1    2    3    4    5    6    7    8    9    ...  169  170  \
0    0.585963  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
1    0.831822  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
2    0.831822  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
3    0.831822  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  ...  0.0  0.0   
4    0.831822  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
..        ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...  ...   
984  0.989131  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  ...  0.0  0.0   
985  0.989927  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
986  0.990885  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  ...  0.0  0.0   
987  0.990911  0.0  1.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  ...  0.0  0.0   
988  0.991843  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0   
     171  172  173  174  175  176  177  178  
0    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
1    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
2    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
3    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
4    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
..   ...  ...  ...  ...  ...  ...  ...  ...  
984  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
985  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
986  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
987  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  
988  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
       0    1    2
0    1.0  0.0  0.0
1    1.0  0.0  0.0
2    0.0  0.0  1.0
3    1.0  0.0  0.0
4    0.0  0.0  1.0
..   ...  ...  ...
984  1.0  0.0  0.0
985  0.0  0.0  1.0
986  0.0  1.0  0.0