Bihaqo / t3f

Tensor Train decomposition on TensorFlow
https://t3f.readthedocs.io/en/latest/index.html
MIT License

How to use in 3d array? #223

Open wuqingle opened 7 months ago

wuqingle commented 7 months ago

Hello, I tried to feed three-dimensional data into a model and replaced the dense layer with t3f.nn.KerasDense, but the None (batch) dimension went missing from the layer's output shape. How should t3f.nn.KerasDense be used with three-dimensional data?

import numpy as np
import t3f
from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, Dropout, Dense, Flatten, Reshape

# This generates some test samples for checking the code
X_train = np.random.rand(100, 4, 400)
Y_train = np.random.rand(100, 2)

model = Sequential()
model.add(Conv1D(32, 3, activation='relu', input_shape=(4, 400)))
model.add(MaxPooling1D(2))
model.add(Dropout(0.5))
model.add(Flatten())  # <- You need a flatten here

tt_layer = t3f.nn.KerasDense(input_dims=[4, 4, 2, 1], output_dims=[4, 4, 2, 1], tt_rank=16, activation='relu')
model.add(tt_layer)
model.add(Dense(32, activation='relu'))
model.add(Reshape((1, 32)))
model.add(Flatten())
model.add(Dense(2, activation='sigmoid'))  # <- the last dense must have output 2

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
model.fit(X_train, Y_train, batch_size=16, epochs=10)
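As a sanity check on the shapes, the listing above can be traced by hand in plain Python, without importing Keras or t3f. This is only a sketch: it assumes the Keras defaults (Conv1D with 'valid' padding and stride 1, MaxPooling1D with stride equal to the pool size), and the helper names are hypothetical.

```python
from math import prod


def conv1d_out_len(length, kernel_size):
    # Output length for 'valid' padding and stride 1 (assumed Keras defaults).
    return length - kernel_size + 1


def pool1d_out_len(length, pool_size):
    # Output length for 'valid' padding and stride == pool_size (assumed defaults).
    return (length - pool_size) // pool_size + 1


steps, channels = 4, 400              # input_shape=(4, 400)
steps = conv1d_out_len(steps, 3)      # Conv1D(32, 3): 2 steps remain
channels = 32                         # Conv1D(32, 3): 32 filters
steps = pool1d_out_len(steps, 2)      # MaxPooling1D(2): 1 step remains
flat_features = steps * channels      # Flatten: steps * channels per sample

input_dims = [4, 4, 2, 1]
print(flat_features, prod(input_dims))
```

Under these assumptions the flattened feature count equals the product of input_dims, which is the input size the TT layer is built for.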

The model summary is:

Model: "sequential_11"

Layer (type)                 Output Shape              Param #
=================================================================
conv1d_10 (Conv1D)           (None, 2, 32)             38432
tt_dense_10 (KerasDense)     (2, 32)                   5424
flatten_7 (Flatten)          (2, 32)                   0
dense_9 (Dense)              (2, 2)                    66
=================================================================
Total params: 43922 (171.57 KB)
Trainable params: 43922 (171.57 KB)
Non-trainable params: 0 (0.00 Byte)
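The parameter counts in this summary can be reproduced with a few lines of arithmetic. The sketch below assumes that each TT-core has shape (r_prev, n_k, m_k, r_next) with boundary ranks of 1 and all internal ranks equal to tt_rank, and that the TT layer adds one bias per output unit; the helper name is hypothetical and is not part of t3f.

```python
from math import prod


def tt_matrix_params(input_dims, output_dims, tt_rank):
    # Assumed core shapes: (r_prev, n_k, m_k, r_next), boundary ranks equal to 1.
    d = len(input_dims)
    ranks = [1] + [tt_rank] * (d - 1) + [1]
    return sum(
        ranks[k] * input_dims[k] * output_dims[k] * ranks[k + 1]
        for k in range(d)
    )


# Conv1D: kernel_size * in_channels * filters + one bias per filter
conv_params = 3 * 400 * 32 + 32

# TT layer: cores plus an assumed bias vector of length prod(output_dims)
tt_params = tt_matrix_params([4, 4, 2, 1], [4, 4, 2, 1], 16) + prod([4, 4, 2, 1])

# Final Dense: 32 inputs * 2 outputs + 2 biases
dense_params = 32 * 2 + 2

print(conv_params, tt_params, dense_params, conv_params + tt_params + dense_params)
```

These values agree with the Param # column, which suggests the TT layer was built for 32 input features while the layer before it emits a (2, 32) tensor per sample.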

An error occurred in model.fit:

File "/usr/local/lib/python3.10/dist-packages/t3f/ops.py", line 231, in tt_dense_matmul
Input to reshape is a tensor with 1024 values, but the requested shape has 64
[[{{node sequential_11/tt_dense_10/t3f_matmul/Reshape_4}}]] [Op:__inference_train_function_9879]
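The two numbers in the message can be reproduced from values already shown in this post. The sketch below is only arithmetic on the batch size passed to model.fit and the conv1d_10 output shape from the summary; it does not show what t3f does internally, and reading 64 as "one sample's worth of values" is an assumption.

```python
from math import prod

batch_size = 16                  # from model.fit(..., batch_size=16)
conv_out = (2, 32)               # per-sample output shape of conv1d_10 in the summary
input_dims = [4, 4, 2, 1]        # input_dims passed to t3f.nn.KerasDense

values_in_batch = batch_size * prod(conv_out)   # values the TT layer receives per batch
values_per_sample = prod(conv_out)              # values in one unflattened sample
tt_input_size = prod(input_dims)                # features the TT layer is built for

print(values_in_batch, values_per_sample, tt_input_size)
```

The first two values match the 1024 and 64 in the error, which is consistent with the TT layer receiving the unflattened (None, 2, 32) tensor. Note also that the summary lists no pooling or flatten layer before tt_dense_10, unlike the listing above, and that flattening (2, 32) directly would give 64 features per sample rather than the 32 implied by input_dims.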