import math

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    ELU,
    Flatten,
    Input,
    MaxPooling2D,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Load the CIFAR-10 dataset: 50,000 training and 10,000 test RGB images
# (32x32x3), with integer class labels 0-9.
# NOTE: a syntactically invalid duplicated paste of all the import lines
# (already present at the top of the file) was removed from here.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Normalize pixel values from [0, 255] to [0.0, 1.0].
# (The original had both assignments fused onto one line — invalid syntax.)
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# Model definition: a small VGG-style CNN — two conv stages
# (32 then 64 filters, each followed by max-pooling, dropout and batch
# normalization) feeding a 512-unit dense head and a 10-way softmax.
# An explicit Input layer replaces the deprecated `input_shape=` kwarg on
# the first Conv2D (the original run logged a Keras UserWarning telling us
# to prefer an Input object as the first layer of a Sequential model).
model = Sequential([
    Input(shape=x_train.shape[1:]),

    # Stage 1: 32 filters.
    Conv2D(32, (3, 3), padding='same'),
    ELU(),
    BatchNormalization(),
    Conv2D(32, (3, 3)),
    ELU(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.2),
    BatchNormalization(),

    # Stage 2: 64 filters.
    Conv2D(64, (3, 3), padding='same'),
    ELU(),
    BatchNormalization(),
    Conv2D(64, (3, 3)),
    ELU(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.3),
    BatchNormalization(),

    # Classifier head.
    Flatten(),
    Dense(512),
    ELU(),
    Dropout(0.4),
    BatchNormalization(),
    Dense(10, activation='softmax'),
])
# Compile the model. Labels are integer class ids (not one-hot), hence
# sparse_categorical_crossentropy; Adam starts at 1e-3 and is later
# reduced by the LearningRateScheduler callback.
model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Print the layer-by-layer architecture and parameter counts.
model.summary()
# Data augmentation: mild random rotations, shifts, zooms and horizontal
# flips applied on the fly to each training batch.
augmentation = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.2,
)
datagen = ImageDataGenerator(**augmentation)
# Early stopping: halt when validation loss has not improved for 5 epochs
# and roll the model back to the best weights seen so far.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)
# Learning-rate schedule: hold the initial rate for the first 10 epochs,
# then decay it by exp(-0.1) (~x0.905) every epoch thereafter.
# The decay factor is hoisted to a module constant computed with math.exp;
# the original rebuilt an eager TF tensor (tf.math.exp(-0.1).numpy())
# on every epoch for the same constant value.
_LR_DECAY = math.exp(-0.1)


def scheduler(epoch, lr):
    """Return the learning rate for `epoch`, given the current rate `lr`."""
    if epoch < 10:
        return float(lr)
    return float(lr * _LR_DECAY)


lr_scheduler = LearningRateScheduler(scheduler)
# Train for up to 50 epochs on augmented batches of 64. The held-out test
# set is used as the validation set for monitoring, the LR schedule, and
# early stopping.
train_batches = datagen.flow(x_train, y_train, batch_size=64)
history = model.fit(
    train_batches,
    epochs=50,
    validation_data=(x_test, y_test),
    callbacks=[early_stopping, lr_scheduler],
    verbose=1,
)
# Plot training/validation accuracy (left) and loss (right) side by side.
hist = history.history
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(hist['accuracy'], label='Training Accuracy')
plt.plot(hist['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(hist['loss'], label='Training Loss')
plt.plot(hist['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()
# Final evaluation on the held-out test set.
# evaluate() returns [loss, accuracy] because 'accuracy' was the only
# compiled metric.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# NOTE(review): captured runtime log output was fused onto this line in the
# original paste and has been removed. Key results from that run: training
# early-stopped at epoch 26, weights restored from epoch 21, final
# test loss ~0.486 and test accuracy ~0.832.
input_shape
/input_dim
argument to a layer. When using Sequential models, prefer using anInput(shape)
object as the first layer in the model instead. super().init( Model: "sequential" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 32, 32, 32) │ 896 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ elu (ELU) │ (None, 32, 32, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization │ (None, 32, 32, 32) │ 128 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 30, 30, 32) │ 9,248 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ elu_1 (ELU) │ (None, 30, 30, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 15, 15, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 15, 15, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_1 │ (None, 15, 15, 32) │ 128 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 15, 15, 64) │ 18,496 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ elu_2 (ELU) │ (None, 15, 15, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_2 │ (None, 15, 15, 64) │ 256 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 13, 13, 64) │ 36,928 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ elu_3 (ELU) │ (None, 13, 13, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ 
max_pooling2d_1 (MaxPooling2D) │ (None, 6, 6, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 6, 6, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_3 │ (None, 6, 6, 64) │ 256 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten (Flatten) │ (None, 2304) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 512) │ 1,180,160 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ elu_4 (ELU) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_2 (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_4 │ (None, 512) │ 2,048 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 5,130 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 1,253,674 (4.78 MB) Trainable params: 1,252,266 (4.78 MB) Non-trainable params: 1,408 (5.50 KB) Epoch 1/50 /opt/conda/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:120: UserWarning: YourPyDataset
class should callsuper().__init__(**kwargs)
in its constructor.**kwargs
can includeworkers
,use_multiprocessing
,max_queue_size
. Do not pass these arguments tofit()
, as they will be ignored. self._warn_if_super_not_called() 6/782 ━━━━━━━━━━━━━━━━━━━━ 22s 29ms/step - accuracy: 0.1493 - loss: 2.9058 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1715369637.114599 103 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. W0000 00:00:1715369637.135243 103 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 704/782 ━━━━━━━━━━━━━━━━━━━━ 3s 46ms/step - accuracy: 0.3613 - loss: 1.9338 W0000 00:00:1715369669.371927 101 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 777/782 ━━━━━━━━━━━━━━━━━━━━ 0s 45ms/step - accuracy: 0.3683 - loss: 1.9052 W0000 00:00:1715369674.194460 103 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 782/782 ━━━━━━━━━━━━━━━━━━━━ 51s 49ms/step - accuracy: 0.3689 - loss: 1.9030 - val_accuracy: 0.5600 - val_loss: 1.3056 - learning_rate: 0.0010 Epoch 2/50 W0000 00:00:1715369675.279710 102 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.5644 - loss: 1.2240 - val_accuracy: 0.6609 - val_loss: 0.9819 - learning_rate: 0.0010 Epoch 3/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.6344 - loss: 1.0428 - val_accuracy: 0.6541 - val_loss: 1.0148 - learning_rate: 0.0010 Epoch 4/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.6708 - loss: 0.9441 - val_accuracy: 0.6982 - val_loss: 0.9036 - learning_rate: 0.0010 Epoch 5/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.6915 - loss: 0.8878 - val_accuracy: 0.7359 - val_loss: 0.7684 - learning_rate: 0.0010 Epoch 6/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7080 - loss: 0.8411 - val_accuracy: 0.7078 - val_loss: 0.8716 - learning_rate: 0.0010 Epoch 7/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7172 - 
loss: 0.8128 - val_accuracy: 0.7440 - val_loss: 0.7424 - learning_rate: 0.0010 Epoch 8/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.7289 - loss: 0.7806 - val_accuracy: 0.7822 - val_loss: 0.6465 - learning_rate: 0.0010 Epoch 9/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7389 - loss: 0.7427 - val_accuracy: 0.7448 - val_loss: 0.7539 - learning_rate: 0.0010 Epoch 10/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7406 - loss: 0.7395 - val_accuracy: 0.7293 - val_loss: 0.8287 - learning_rate: 0.0010 Epoch 11/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7521 - loss: 0.7086 - val_accuracy: 0.7673 - val_loss: 0.7102 - learning_rate: 9.0484e-04 Epoch 12/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7648 - loss: 0.6811 - val_accuracy: 0.7883 - val_loss: 0.6180 - learning_rate: 8.1873e-04 Epoch 13/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7701 - loss: 0.6625 - val_accuracy: 0.7576 - val_loss: 0.7172 - learning_rate: 7.4082e-04 Epoch 14/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.7776 - loss: 0.6416 - val_accuracy: 0.7832 - val_loss: 0.6337 - learning_rate: 6.7032e-04 Epoch 15/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7760 - loss: 0.6392 - val_accuracy: 0.8118 - val_loss: 0.5571 - learning_rate: 6.0653e-04 Epoch 16/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.7857 - loss: 0.6161 - val_accuracy: 0.7959 - val_loss: 0.5869 - learning_rate: 5.4881e-04 Epoch 17/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.7866 - loss: 0.6088 - val_accuracy: 0.8115 - val_loss: 0.5555 - learning_rate: 4.9659e-04 Epoch 18/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.7918 - loss: 0.5961 - val_accuracy: 0.7994 - val_loss: 0.5883 - learning_rate: 4.4933e-04 Epoch 19/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.7950 - loss: 0.5864 - val_accuracy: 0.8200 - val_loss: 0.5268 - learning_rate: 4.0657e-04 Epoch 20/50 782/782 
━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.8032 - loss: 0.5714 - val_accuracy: 0.8290 - val_loss: 0.5057 - learning_rate: 3.6788e-04 Epoch 21/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.8036 - loss: 0.5617 - val_accuracy: 0.8320 - val_loss: 0.4857 - learning_rate: 3.3287e-04 Epoch 22/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.8051 - loss: 0.5646 - val_accuracy: 0.8222 - val_loss: 0.5316 - learning_rate: 3.0119e-04 Epoch 23/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 37ms/step - accuracy: 0.8027 - loss: 0.5640 - val_accuracy: 0.8293 - val_loss: 0.5011 - learning_rate: 2.7253e-04 Epoch 24/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 29s 36ms/step - accuracy: 0.8085 - loss: 0.5496 - val_accuracy: 0.8151 - val_loss: 0.5467 - learning_rate: 2.4660e-04 Epoch 25/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 28s 36ms/step - accuracy: 0.8112 - loss: 0.5422 - val_accuracy: 0.8368 - val_loss: 0.4892 - learning_rate: 2.2313e-04 Epoch 26/50 782/782 ━━━━━━━━━━━━━━━━━━━━ 30s 39ms/step - accuracy: 0.8128 - loss: 0.5331 - val_accuracy: 0.8332 - val_loss: 0.4886 - learning_rate: 2.0190e-04 Epoch 26: early stopping Restoring model weights from the end of the best epoch: 21.Test loss: 0.4857422709465027 Test accuracy: 0.8320000171661377