Complete CNN
A complete example that brings together everything we have seen so far, plus a bit more...¶
We will use the CIFAR-10 dataset, which contains 60,000 images split across 10 classes, with 6,000 images per class.
The images are 32x32 pixels with three color channels (RGB).
Throughout this notebook we will revisit some concepts we have already covered and introduce a few new ones.
IMPORTANT REMINDER: remember to set the Colab runtime to use a GPU.
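Before training, it can help to confirm that the GPU runtime is actually active. A minimal check (not part of the original notebook):
import tensorflow as tf
# Lists the GPUs visible to TensorFlow; an empty list means the runtime is CPU-only
print(tf.config.list_physical_devices('GPU'))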
Setting Up Google Drive¶
Let's mount Google Drive so we can save the model there during training.
from google.colab import drive
drive.mount('/content/drive')
# Define the path where the model will be saved on Google Drive
model_save_path = '/content/drive/MyDrive/checkpoints/cifar10_best_model.h5'
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Importing and Preparing the CIFAR-10 Data¶
First, let's import and prepare the CIFAR-10 data:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
# Load the CIFAR-10 dataset (the test split is used as the validation set here)
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
# Normalize the pixel values to the [0, 1] range
x_train = x_train.astype('float32') / 255.0
x_val = x_val.astype('float32') / 255.0
# Convert the labels to one-hot vectors
y_train = to_categorical(y_train, 10)
y_val = to_categorical(y_val, 10)
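As a quick sanity check (not in the original notebook), we can print the array shapes we expect after these steps:
# Expected shapes after preprocessing
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_val.shape, y_val.shape)      # (10000, 32, 32, 3) (10000, 10)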
Defining the Model¶
Let's define the model for this dataset:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Define the model
model = Sequential()
# Convolutional layers with an increasing number of filters
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
# Flatten layer to convert the convolutional output into a 1D vector
model.add(Flatten())
# Fully connected layers
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Final fully connected layer with 10 outputs (one per image class)
model.add(Dense(10, activation='softmax'))
# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Model summary
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_3 (Conv2D) (None, 30, 30, 32) 896 batch_normalization_3 (Bat (None, 30, 30, 32) 128 chNormalization) max_pooling2d_3 (MaxPoolin (None, 15, 15, 32) 0 g2D) conv2d_4 (Conv2D) (None, 13, 13, 64) 18496 batch_normalization_4 (Bat (None, 13, 13, 64) 256 chNormalization) max_pooling2d_4 (MaxPoolin (None, 6, 6, 64) 0 g2D) conv2d_5 (Conv2D) (None, 4, 4, 128) 73856 batch_normalization_5 (Bat (None, 4, 4, 128) 512 chNormalization) max_pooling2d_5 (MaxPoolin (None, 2, 2, 128) 0 g2D) flatten_1 (Flatten) (None, 512) 0 dense_3 (Dense) (None, 128) 65664 dropout_2 (Dropout) (None, 128) 0 dense_4 (Dense) (None, 64) 8256 dropout_3 (Dropout) (None, 64) 0 dense_5 (Dense) (None, 10) 650 ================================================================= Total params: 168714 (659.04 KB) Trainable params: 168266 (657.29 KB) Non-trainable params: 448 (1.75 KB) _________________________________________________________________
Setting Up the Callbacks¶
Let's set up the ModelCheckpoint, EarlyStopping, and ReduceLROnPlateau callbacks:
# Set up the ModelCheckpoint callback to save the model to Google Drive
checkpoint = ModelCheckpoint(model_save_path,
monitor='val_accuracy',
save_best_only=True,
mode='max',
verbose=1)
# Set up the EarlyStopping callback
early_stopping = EarlyStopping(monitor='val_loss',
patience=10,
verbose=1,
restore_best_weights=True)
# Set up the ReduceLROnPlateau callback
# Note: min_lr=0.001 equals Adam's default initial learning rate, so in this run the
# learning rate is never actually reduced (the training log shows lr fixed at 0.0010)
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
factor=0.2,
patience=5,
min_lr=0.001,
verbose=1)
Setting Up Data Augmentation¶
Let's set up data augmentation using ImageDataGenerator:
# Configure data augmentation
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True,
zoom_range=0.2
)
# Create a generator that yields augmented training batches
train_generator = datagen.flow(x_train, y_train, batch_size=32)
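To get a feel for what the generator produces, here is an optional visualization of one augmented batch (not part of the original notebook):
import matplotlib.pyplot as plt
# Take one augmented batch and display the first five images
batch_images, batch_labels = next(train_generator)
plt.figure(figsize=(10, 2))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(batch_images[i])
    plt.axis('off')
plt.show()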
Training the Model¶
The moment has finally arrived: let's train the model with the callbacks and data augmentation:
# Train the model with the callbacks and data augmentation
history = model.fit(train_generator,
epochs=100,
validation_data=(x_val, y_val),
callbacks=[checkpoint, early_stopping, reduce_lr])
Epoch 1/100 1562/1563 [============================>.] - ETA: 0s - loss: 2.0423 - accuracy: 0.2451 Epoch 1: val_accuracy improved from -inf to 0.28590, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 43s 23ms/step - loss: 2.0421 - accuracy: 0.2451 - val_loss: 2.0514 - val_accuracy: 0.2859 - lr: 0.0010 Epoch 2/100 1/1563 [..............................] - ETA: 52s - loss: 1.8313 - accuracy: 0.1562
/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`. saving_api.save_model(
1561/1563 [============================>.] - ETA: 0s - loss: 1.7979 - accuracy: 0.3332 Epoch 2: val_accuracy improved from 0.28590 to 0.45100, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.7977 - accuracy: 0.3333 - val_loss: 1.4733 - val_accuracy: 0.4510 - lr: 0.0010 Epoch 3/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.6728 - accuracy: 0.3916 Epoch 3: val_accuracy improved from 0.45100 to 0.48650, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 37s 23ms/step - loss: 1.6729 - accuracy: 0.3917 - val_loss: 1.3924 - val_accuracy: 0.4865 - lr: 0.0010 Epoch 4/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.6007 - accuracy: 0.4219 Epoch 4: val_accuracy did not improve from 0.48650 1563/1563 [==============================] - 37s 23ms/step - loss: 1.6007 - accuracy: 0.4219 - val_loss: 1.5411 - val_accuracy: 0.4524 - lr: 0.0010 Epoch 5/100 1563/1563 [==============================] - ETA: 0s - loss: 1.5361 - accuracy: 0.4527 Epoch 5: val_accuracy improved from 0.48650 to 0.49570, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.5361 - accuracy: 0.4527 - val_loss: 1.4323 - val_accuracy: 0.4957 - lr: 0.0010 Epoch 6/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.4981 - accuracy: 0.4708 Epoch 6: val_accuracy did not improve from 0.49570 1563/1563 [==============================] - 35s 22ms/step - loss: 1.4981 - accuracy: 0.4708 - val_loss: 1.4484 - val_accuracy: 0.4790 - lr: 0.0010 Epoch 7/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.4468 - accuracy: 0.4942 Epoch 7: val_accuracy improved from 0.49570 to 0.56870, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 23ms/step - loss: 1.4468 - accuracy: 0.4941 - val_loss: 1.1967 - val_accuracy: 0.5687 - lr: 0.0010 Epoch 8/100 1563/1563 [==============================] - ETA: 0s - loss: 1.4102 - accuracy: 0.5118 Epoch 8: val_accuracy improved from 0.56870 to 0.60030, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.4102 - accuracy: 0.5118 - val_loss: 1.1149 - val_accuracy: 0.6003 - lr: 0.0010 Epoch 9/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.3760 - accuracy: 0.5249 Epoch 9: val_accuracy did not improve from 0.60030 1563/1563 [==============================] - 34s 22ms/step - loss: 1.3763 - accuracy: 0.5249 - val_loss: 1.2569 - val_accuracy: 0.5386 - lr: 0.0010 Epoch 10/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.3509 - accuracy: 0.5376 Epoch 10: val_accuracy improved from 0.60030 to 0.63180, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 22ms/step - loss: 1.3510 - accuracy: 0.5376 - val_loss: 1.0326 - val_accuracy: 0.6318 - lr: 0.0010 Epoch 11/100 1563/1563 [==============================] - ETA: 0s - loss: 1.3254 - accuracy: 0.5463 Epoch 11: val_accuracy did not improve from 0.63180 1563/1563 [==============================] - 36s 23ms/step - loss: 1.3254 - accuracy: 0.5463 - val_loss: 1.0348 - val_accuracy: 0.6258 - lr: 0.0010 Epoch 12/100 1562/1563 [============================>.] 
- ETA: 0s - loss: 1.3049 - accuracy: 0.5532 Epoch 12: val_accuracy did not improve from 0.63180 1563/1563 [==============================] - 35s 23ms/step - loss: 1.3051 - accuracy: 0.5532 - val_loss: 1.1351 - val_accuracy: 0.6115 - lr: 0.0010 Epoch 13/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.2862 - accuracy: 0.5637 Epoch 13: val_accuracy improved from 0.63180 to 0.63310, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 22ms/step - loss: 1.2863 - accuracy: 0.5636 - val_loss: 1.0346 - val_accuracy: 0.6331 - lr: 0.0010 Epoch 14/100 1563/1563 [==============================] - ETA: 0s - loss: 1.2669 - accuracy: 0.5699 Epoch 14: val_accuracy did not improve from 0.63310 1563/1563 [==============================] - 35s 23ms/step - loss: 1.2669 - accuracy: 0.5699 - val_loss: 1.2640 - val_accuracy: 0.5506 - lr: 0.0010 Epoch 15/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.2572 - accuracy: 0.5741 Epoch 15: val_accuracy did not improve from 0.63310 1563/1563 [==============================] - 34s 22ms/step - loss: 1.2573 - accuracy: 0.5740 - val_loss: 1.1206 - val_accuracy: 0.6215 - lr: 0.0010 Epoch 16/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.2447 - accuracy: 0.5805 Epoch 16: val_accuracy did not improve from 0.63310 1563/1563 [==============================] - 35s 23ms/step - loss: 1.2446 - accuracy: 0.5805 - val_loss: 1.1991 - val_accuracy: 0.6057 - lr: 0.0010 Epoch 17/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.2226 - accuracy: 0.5873 Epoch 17: val_accuracy improved from 0.63310 to 0.64720, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 22ms/step - loss: 1.2227 - accuracy: 0.5873 - val_loss: 0.9827 - val_accuracy: 0.6472 - lr: 0.0010 Epoch 18/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.2050 - accuracy: 0.5965 Epoch 18: val_accuracy did not improve from 0.64720 1563/1563 [==============================] - 35s 22ms/step - loss: 1.2047 - accuracy: 0.5966 - val_loss: 1.0118 - val_accuracy: 0.6391 - lr: 0.0010 Epoch 19/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1952 - accuracy: 0.5983 Epoch 19: val_accuracy improved from 0.64720 to 0.65800, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 37s 23ms/step - loss: 1.1952 - accuracy: 0.5983 - val_loss: 0.9626 - val_accuracy: 0.6580 - lr: 0.0010 Epoch 20/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.1898 - accuracy: 0.6032 Epoch 20: val_accuracy improved from 0.65800 to 0.68460, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 37s 24ms/step - loss: 1.1898 - accuracy: 0.6032 - val_loss: 0.9178 - val_accuracy: 0.6846 - lr: 0.0010 Epoch 21/100 1562/1563 [============================>.] 
- ETA: 0s - loss: 1.1815 - accuracy: 0.6051 Epoch 21: val_accuracy did not improve from 0.68460 1563/1563 [==============================] - 37s 23ms/step - loss: 1.1815 - accuracy: 0.6051 - val_loss: 1.0920 - val_accuracy: 0.6279 - lr: 0.0010 Epoch 22/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1674 - accuracy: 0.6085 Epoch 22: val_accuracy improved from 0.68460 to 0.69460, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1674 - accuracy: 0.6085 - val_loss: 0.8778 - val_accuracy: 0.6946 - lr: 0.0010 Epoch 23/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1468 - accuracy: 0.6188 Epoch 23: val_accuracy improved from 0.69460 to 0.71950, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1468 - accuracy: 0.6188 - val_loss: 0.8091 - val_accuracy: 0.7195 - lr: 0.0010 Epoch 24/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1462 - accuracy: 0.6197 Epoch 24: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 35s 23ms/step - loss: 1.1462 - accuracy: 0.6197 - val_loss: 0.9789 - val_accuracy: 0.6681 - lr: 0.0010 Epoch 25/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1482 - accuracy: 0.6186 Epoch 25: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1482 - accuracy: 0.6186 - val_loss: 0.9466 - val_accuracy: 0.6623 - lr: 0.0010 Epoch 26/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.1343 - accuracy: 0.6254 Epoch 26: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 35s 23ms/step - loss: 1.1342 - accuracy: 0.6254 - val_loss: 0.8559 - val_accuracy: 0.7100 - lr: 0.0010 Epoch 27/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.1221 - accuracy: 0.6244 Epoch 27: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1218 - accuracy: 0.6245 - val_loss: 0.9000 - val_accuracy: 0.6829 - lr: 0.0010 Epoch 28/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.1122 - accuracy: 0.6333 Epoch 28: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1123 - accuracy: 0.6333 - val_loss: 0.8558 - val_accuracy: 0.7058 - lr: 0.0010 Epoch 29/100 1561/1563 [============================>.] 
- ETA: 0s - loss: 1.1095 - accuracy: 0.6322 Epoch 29: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 36s 23ms/step - loss: 1.1096 - accuracy: 0.6322 - val_loss: 0.9187 - val_accuracy: 0.6892 - lr: 0.0010 Epoch 30/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1052 - accuracy: 0.6344 Epoch 30: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 35s 22ms/step - loss: 1.1052 - accuracy: 0.6344 - val_loss: 0.8275 - val_accuracy: 0.7135 - lr: 0.0010 Epoch 31/100 1563/1563 [==============================] - ETA: 0s - loss: 1.1002 - accuracy: 0.6357 Epoch 31: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 35s 23ms/step - loss: 1.1002 - accuracy: 0.6357 - val_loss: 1.0448 - val_accuracy: 0.6508 - lr: 0.0010 Epoch 32/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0892 - accuracy: 0.6390 Epoch 32: val_accuracy did not improve from 0.71950 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0892 - accuracy: 0.6390 - val_loss: 0.8500 - val_accuracy: 0.7127 - lr: 0.0010 Epoch 33/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0825 - accuracy: 0.6395 Epoch 33: val_accuracy improved from 0.71950 to 0.73240, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 22ms/step - loss: 1.0825 - accuracy: 0.6395 - val_loss: 0.7995 - val_accuracy: 0.7324 - lr: 0.0010 Epoch 34/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0815 - accuracy: 0.6441 Epoch 34: val_accuracy did not improve from 0.73240 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0817 - accuracy: 0.6440 - val_loss: 0.7821 - val_accuracy: 0.7324 - lr: 0.0010 Epoch 35/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0743 - accuracy: 0.6473 Epoch 35: val_accuracy did not improve from 0.73240 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0742 - accuracy: 0.6473 - val_loss: 0.7895 - val_accuracy: 0.7282 - lr: 0.0010 Epoch 36/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0695 - accuracy: 0.6486 Epoch 36: val_accuracy did not improve from 0.73240 1563/1563 [==============================] - 37s 24ms/step - loss: 1.0693 - accuracy: 0.6487 - val_loss: 0.8054 - val_accuracy: 0.7223 - lr: 0.0010 Epoch 37/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0633 - accuracy: 0.6510 Epoch 37: val_accuracy improved from 0.73240 to 0.73660, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 23ms/step - loss: 1.0631 - accuracy: 0.6511 - val_loss: 0.7637 - val_accuracy: 0.7366 - lr: 0.0010 Epoch 38/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0701 - accuracy: 0.6509 Epoch 38: val_accuracy did not improve from 0.73660 1563/1563 [==============================] - 37s 23ms/step - loss: 1.0701 - accuracy: 0.6509 - val_loss: 0.8457 - val_accuracy: 0.7136 - lr: 0.0010 Epoch 39/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0467 - accuracy: 0.6584 Epoch 39: val_accuracy did not improve from 0.73660 1563/1563 [==============================] - 38s 24ms/step - loss: 1.0463 - accuracy: 0.6585 - val_loss: 0.8873 - val_accuracy: 0.7043 - lr: 0.0010 Epoch 40/100 1562/1563 [============================>.] 
- ETA: 0s - loss: 1.0546 - accuracy: 0.6548 Epoch 40: val_accuracy did not improve from 0.73660 1563/1563 [==============================] - 37s 24ms/step - loss: 1.0547 - accuracy: 0.6547 - val_loss: 0.8820 - val_accuracy: 0.6985 - lr: 0.0010 Epoch 41/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0483 - accuracy: 0.6572 Epoch 41: val_accuracy did not improve from 0.73660 1563/1563 [==============================] - 37s 24ms/step - loss: 1.0483 - accuracy: 0.6572 - val_loss: 1.1795 - val_accuracy: 0.6351 - lr: 0.0010 Epoch 42/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0463 - accuracy: 0.6583 Epoch 42: val_accuracy improved from 0.73660 to 0.74890, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 22ms/step - loss: 1.0463 - accuracy: 0.6583 - val_loss: 0.7433 - val_accuracy: 0.7489 - lr: 0.0010 Epoch 43/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0436 - accuracy: 0.6599 Epoch 43: val_accuracy did not improve from 0.74890 1563/1563 [==============================] - 38s 25ms/step - loss: 1.0436 - accuracy: 0.6599 - val_loss: 0.8552 - val_accuracy: 0.7113 - lr: 0.0010 Epoch 44/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0321 - accuracy: 0.6626 Epoch 44: val_accuracy did not improve from 0.74890 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0321 - accuracy: 0.6626 - val_loss: 0.7819 - val_accuracy: 0.7274 - lr: 0.0010 Epoch 45/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0337 - accuracy: 0.6601 Epoch 45: val_accuracy did not improve from 0.74890 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0336 - accuracy: 0.6602 - val_loss: 0.7743 - val_accuracy: 0.7347 - lr: 0.0010 Epoch 46/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0330 - accuracy: 0.6638 Epoch 46: val_accuracy improved from 0.74890 to 0.75150, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0330 - accuracy: 0.6638 - val_loss: 0.7259 - val_accuracy: 0.7515 - lr: 0.0010 Epoch 47/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0208 - accuracy: 0.6666 Epoch 47: val_accuracy improved from 0.75150 to 0.76160, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 35s 23ms/step - loss: 1.0209 - accuracy: 0.6665 - val_loss: 0.7046 - val_accuracy: 0.7616 - lr: 0.0010 Epoch 48/100 1562/1563 [============================>.] - ETA: 0s - loss: 1.0166 - accuracy: 0.6695 Epoch 48: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 35s 22ms/step - loss: 1.0166 - accuracy: 0.6695 - val_loss: 0.9199 - val_accuracy: 0.6951 - lr: 0.0010 Epoch 49/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0137 - accuracy: 0.6683 Epoch 49: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 35s 23ms/step - loss: 1.0139 - accuracy: 0.6682 - val_loss: 0.7090 - val_accuracy: 0.7582 - lr: 0.0010 Epoch 50/100 1561/1563 [============================>.] 
- ETA: 0s - loss: 1.0255 - accuracy: 0.6638 Epoch 50: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 35s 22ms/step - loss: 1.0254 - accuracy: 0.6638 - val_loss: 0.7449 - val_accuracy: 0.7444 - lr: 0.0010 Epoch 51/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0145 - accuracy: 0.6688 Epoch 51: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 38s 24ms/step - loss: 1.0145 - accuracy: 0.6688 - val_loss: 0.8169 - val_accuracy: 0.7207 - lr: 0.0010 Epoch 52/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0186 - accuracy: 0.6681 Epoch 52: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 37s 24ms/step - loss: 1.0186 - accuracy: 0.6681 - val_loss: 0.7924 - val_accuracy: 0.7355 - lr: 0.0010 Epoch 53/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0113 - accuracy: 0.6686 Epoch 53: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0113 - accuracy: 0.6685 - val_loss: 0.7319 - val_accuracy: 0.7498 - lr: 0.0010 Epoch 54/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0109 - accuracy: 0.6712 Epoch 54: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0109 - accuracy: 0.6712 - val_loss: 1.0983 - val_accuracy: 0.6648 - lr: 0.0010 Epoch 55/100 1561/1563 [============================>.] - ETA: 0s - loss: 1.0129 - accuracy: 0.6713 Epoch 55: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0129 - accuracy: 0.6713 - val_loss: 0.8429 - val_accuracy: 0.7179 - lr: 0.0010 Epoch 56/100 1561/1563 [============================>.] - ETA: 0s - loss: 0.9992 - accuracy: 0.6728 Epoch 56: val_accuracy did not improve from 0.76160 1563/1563 [==============================] - 37s 23ms/step - loss: 0.9993 - accuracy: 0.6729 - val_loss: 0.8258 - val_accuracy: 0.7244 - lr: 0.0010 Epoch 57/100 1563/1563 [==============================] - ETA: 0s - loss: 0.9956 - accuracy: 0.6758 Epoch 57: val_accuracy improved from 0.76160 to 0.77110, saving model to /content/drive/MyDrive/checkpoints/cifar10_best_model.h5 1563/1563 [==============================] - 37s 24ms/step - loss: 0.9956 - accuracy: 0.6758 - val_loss: 0.6726 - val_accuracy: 0.7711 - lr: 0.0010 Epoch 58/100 1563/1563 [==============================] - ETA: 0s - loss: 1.0006 - accuracy: 0.6737 Epoch 58: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 36s 23ms/step - loss: 1.0006 - accuracy: 0.6737 - val_loss: 0.8017 - val_accuracy: 0.7251 - lr: 0.0010 Epoch 59/100 1563/1563 [==============================] - ETA: 0s - loss: 0.9981 - accuracy: 0.6732 Epoch 59: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 37s 24ms/step - loss: 0.9981 - accuracy: 0.6732 - val_loss: 0.7429 - val_accuracy: 0.7520 - lr: 0.0010 Epoch 60/100 1562/1563 [============================>.] 
- ETA: 0s - loss: 0.9920 - accuracy: 0.6789 Epoch 60: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 38s 24ms/step - loss: 0.9920 - accuracy: 0.6789 - val_loss: 0.7769 - val_accuracy: 0.7382 - lr: 0.0010 Epoch 61/100 1563/1563 [==============================] - ETA: 0s - loss: 0.9926 - accuracy: 0.6794 Epoch 61: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 36s 23ms/step - loss: 0.9926 - accuracy: 0.6794 - val_loss: 0.8365 - val_accuracy: 0.7146 - lr: 0.0010 Epoch 62/100 1563/1563 [==============================] - ETA: 0s - loss: 0.9928 - accuracy: 0.6751 Epoch 62: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 37s 23ms/step - loss: 0.9928 - accuracy: 0.6751 - val_loss: 0.7080 - val_accuracy: 0.7619 - lr: 0.0010 Epoch 63/100 1562/1563 [============================>.] - ETA: 0s - loss: 0.9901 - accuracy: 0.6787 Epoch 63: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 36s 23ms/step - loss: 0.9901 - accuracy: 0.6787 - val_loss: 0.8295 - val_accuracy: 0.7269 - lr: 0.0010 Epoch 64/100 1562/1563 [============================>.] - ETA: 0s - loss: 0.9890 - accuracy: 0.6796 Epoch 64: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 37s 24ms/step - loss: 0.9888 - accuracy: 0.6796 - val_loss: 0.7512 - val_accuracy: 0.7521 - lr: 0.0010 Epoch 65/100 1563/1563 [==============================] - ETA: 0s - loss: 0.9867 - accuracy: 0.6767 Epoch 65: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 35s 23ms/step - loss: 0.9867 - accuracy: 0.6767 - val_loss: 0.7518 - val_accuracy: 0.7448 - lr: 0.0010 Epoch 66/100 1562/1563 [============================>.] - ETA: 0s - loss: 0.9789 - accuracy: 0.6827 Epoch 66: val_accuracy did not improve from 0.77110 1563/1563 [==============================] - 37s 24ms/step - loss: 0.9788 - accuracy: 0.6828 - val_loss: 0.7259 - val_accuracy: 0.7563 - lr: 0.0010 Epoch 67/100 1562/1563 [============================>.] - ETA: 0s - loss: 0.9799 - accuracy: 0.6839 Epoch 67: val_accuracy did not improve from 0.77110 Restoring model weights from the end of the best epoch: 57. 1563/1563 [==============================] - 37s 24ms/step - loss: 0.9799 - accuracy: 0.6838 - val_loss: 0.8824 - val_accuracy: 0.7192 - lr: 0.0010 Epoch 67: early stopping
Evaluating the Training¶
Let's take a look at the training loss and accuracy curves.
from matplotlib import pyplot as plt
# Plot the training history: loss and accuracy curves
def plot_history(history):
    # Summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
    # Summarize history for accuracy
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
plot_history(history)
Evaluating the Model and Loading the Best Model¶
Let's load the test set and normalize the pixel values (note that this is the same 10,000-image split that was used as validation data above, since CIFAR-10 only provides train and test splits):
# Load the CIFAR-10 test set
(x_test, y_test) = cifar10.load_data()[1]
# Normalize the pixel values to the [0, 1] range
x_test = x_test.astype('float32') / 255.0
# Convert the labels to one-hot vectors
y_test = to_categorical(y_test, 10)
Model Evaluation¶
Let's evaluate the model on the test data.
During training we saved checkpoints to Google Drive, which ensures we are using the model with the best validation performance:
# Load the best model saved during training
best_model = tf.keras.models.load_model(model_save_path)
# Evaluate the best saved model on the test data
best_test_loss, best_test_accuracy = best_model.evaluate(x_test, y_test, verbose=2)
print(f'Best test loss: {best_test_loss}')
print(f'Best test accuracy: {best_test_accuracy}')
313/313 - 1s - loss: 0.6726 - accuracy: 0.7711 - 1s/epoch - 4ms/step Best test loss: 0.6726096868515015 Best test accuracy: 0.7710999846458435
Making Predictions on the Test Images¶
Let's make predictions on the test dataset:
import numpy as np
# Make predictions on the test set
predictions = best_model.predict(x_test)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = np.argmax(y_test, axis=1)
# CIFAR-10 class names
class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
313/313 [==============================] - 1s 2ms/step
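As a small added illustration (not in the original notebook), we can inspect a single prediction and the model's confidence in it:
# Look at the first test image: predicted class, softmax confidence, and true class
i = 0
print(f'Predicted: {class_names[predicted_classes[i]]} ({predictions[i].max():.2%}) - True: {class_names[true_classes[i]]}')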
Exploring the Confusion Matrix¶
The confusion matrix is a tool for evaluating the performance of a classification model. It is a table that summarizes the model's predictions, comparing the predicted labels against the true labels.
How to interpret it:
- True Positives (TP): the number of correct predictions of the positive class.
- True Negatives (TN): the number of correct predictions of the negative class.
- False Positives (FP): the number of incorrect predictions where the negative class was predicted as positive (type I error).
- False Negatives (FN): the number of incorrect predictions where the positive class was predicted as negative (type II error).
The confusion matrix helps you understand not only the model's accuracy but also the kinds of errors it makes. This is useful for:
- Identifying classes that are frequently confused with one another.
- Improving the model by tuning hyperparameters or collecting more data for specific classes.
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Function to plot the confusion matrix
def plot_confusion_matrix(true_labels, predicted_labels, class_names):
    cm = confusion_matrix(true_labels, predicted_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
# Display the confusion matrix
print("Confusion Matrix:")
plot_confusion_matrix(true_classes, predicted_classes, class_names)
Confusion Matrix:
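As a follow-up (not in the original notebook), per-class accuracy can be read from the diagonal of the matrix, which makes it easy to spot the classes the model confuses most often:
# Per-class accuracy: diagonal (correct predictions) divided by the row totals (true counts)
cm = confusion_matrix(true_classes, predicted_classes)
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
for name, acc in zip(class_names, per_class_accuracy):
    print(f'{name}: {acc:.2%}')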
Displaying the Results¶
Let's plot a few images to visualize the results.
import matplotlib.pyplot as plt
# Function to plot images with their true and predicted labels
def plot_images(images, true_labels, predicted_labels, class_names, num_images=10):
    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        plt.imshow(images[i])
        plt.title(f'True: {class_names[true_labels[i]]}\nPred: {class_names[predicted_labels[i]]}')
        plt.axis('off')
    plt.show()
# Function to plot correctly classified images
def plot_correct_predictions(images, true_labels, predicted_labels, class_names):
    correct_indices = np.where(predicted_labels == true_labels)[0]
    plot_images(images[correct_indices], true_labels[correct_indices], predicted_labels[correct_indices], class_names)
# Function to plot misclassified images
def plot_incorrect_predictions(images, true_labels, predicted_labels, class_names):
    incorrect_indices = np.where(predicted_labels != true_labels)[0]
    plot_images(images[incorrect_indices], true_labels[incorrect_indices], predicted_labels[incorrect_indices], class_names)
Calling the functions:
# Display some correctly classified images
print("Correctly Classified Images:")
plot_correct_predictions(x_test, true_classes, predicted_classes, class_names)
# Display some misclassified images
print("Misclassified Images:")
plot_incorrect_predictions(x_test, true_classes, predicted_classes, class_names)
Correctly Classified Images:
Misclassified Images: