Я новичок в TensorFlow и форматирую некоторые данные для передачи в рекуррентную нейронную сеть. Мои данные задаются трехмерным тензором, введенным в заполнитель x
. Я хочу разбить x
по 3-му измерению, и для этого у меня есть (обратите внимание, что n_timesteps
соответствует длине x
по 3-му измерению):
# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)
Хотя, как я пробовал с numpy
:
x = np.split (x, n_timesteps, axis = 2)
Если x
является трехмерным ndarray
, то np.split
вернет список n_timesteps
массивов с измерением 3, так что 3-е измерение является одноэлементным. С numpy
я знаю, что могу легко решить эту проблему, используя np.squeeze
вместе с пониманием списка, чтобы удалить одноэлементное измерение:
x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]
Но как я могу сделать то же самое на TF?