我想做一个LSTM模型,它能够在每个时间段计算三个类别的概率。
我以如下方式定义该模型。
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()
model.add(layers.LSTM(128,input_shape=(None,input_size)))
model.add(layers.Dense(output_size))
model.add(layers.Softmax())
model.compile(optimizer="sgd",loss=keras.losses.CategoricalCrossentropy())
model.fit(Generator())
生成器提供由一个训练实例组成的批次(我计划使用多个例子的批次,但因为我的训练例子长度不一,所以我现在尽可能保持简单的东西).更准确地说,生成器返回形状为(1,T,input_size)的x和形状为(1,T,3)的y。在处理了几个批次后,keras失败了,出现了以下错误
InvalidArgumentError Traceback (most recent call last)
<ipython-input-51-ad3e71307f94> in <module>
1 model.compile(optimizer="sgd",loss=keras.losses.CategoricalCrossentropy())
----> 2 model.fit(Generator())
c:\theproject\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
110 # Running inside `run_distribute_coordinator` already.
c:\theproject\venv\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
782 new_tracing_count = self._get_tracing_count()
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
838 # Lifting succeeded, so variables are initialized and we can run the
839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
841 else:
842 canon_args, canon_kwds = \
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2831 @property
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1841 `args` and `kwargs`.
1842 """
-> 1843 return self._call_flat(
1844 [t for t in nest.flatten((args, kwargs), expand_composites=True)
1845 if isinstance(t, (ops.Tensor,
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1921 and executing_eagerly):
1922 # No tape is watching; skip to running the function.
-> 1923 return self._build_call_outputs(self._inference_function.call(
1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
543 with _InterpolateFunctionError(self):
544 if cancellation_manager is None:
--> 545 outputs = execute.execute(
546 str(self.signature.name),
547 num_outputs=self._num_outputs,
c:\theproject\venv\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: Can not squeeze dim[2], expected a dimension of 1, got 3
[[node categorical_crossentropy/remove_squeezable_dimensions/Squeeze (defined at <ipython-input-51-ad3e71307f94>:2) ]] [Op:__inference_train_function_16213]
Function call stack:
train_function
这里是一个样本(不是唯一的一个)的X和Y,它的失败之处。
0 1 2 3 4 5 6 7 8 9 ... 169 170 \
0 0.585963 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
1 0.831822 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
2 0.831822 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
3 0.831822 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 ... 0.0 0.0
4 0.831822 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
.. ... ... ... ... ... ... ... ... ... ... ... ... ...
984 0.989131 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 ... 0.0 0.0
985 0.989927 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
986 0.990885 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 ... 0.0 0.0
987 0.990911 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 ... 0.0 0.0
988 0.991843 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0
171 172 173 174 175 176 177 178
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
.. ... ... ... ... ... ... ... ...
984 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
985 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
986 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
987 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
988 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0 1 2
0 1.0 0.0 0.0
1 1.0 0.0 0.0
2 0.0 0.0 1.0
3 1.0 0.0 0.0
4 0.0 0.0 1.0
.. ... ... ...
984 1.0 0.0 0.0
985 0.0 0.0 1.0
986 0.0 1.0 0.0