1.数据介绍及预处理
该数据集来源于git,原始数据集为目标检测数据集,经过分类后得到四类,分别为交流、看书、睡觉以及玩手机,数据读取和预处理代码如下:
def get_data(batch_size=32):
    """Build training/validation generators for the 4-class image dataset.

    The images under ``data/`` (one sub-directory per class: talking,
    reading, sleeping, phone-using) are rescaled to [0, 1] and split
    80/20 into training and validation subsets.

    NOTE(review): ``ImageDataGenerator`` must be imported at module level
    (e.g. from ``tensorflow.keras.preprocessing.image``) — the import is
    not visible in this snippet; confirm it exists in util.py.

    Parameters
    ----------
    batch_size : int, optional
        Number of images per generated batch (default 32).

    Returns
    -------
    tuple
        ``(train_ds, val_ds)`` directory iterators yielding
        ``(images, one_hot_labels)`` batches of 224x224x3 images.
    """
    data_dir = 'data'
    # Target image size expected by the downstream 224x224 models.
    img_height = 224
    img_width = 224
    datagen = ImageDataGenerator(
        # Normalize pixel values from [0, 255] to [0, 1].
        rescale=1. / 255,
        # Reserve 20% of the images for validation.
        validation_split=0.2,
    )
    # Same seed for both subsets so the train/validation partition is
    # consistent between the two iterators.
    train_ds = datagen.flow_from_directory(
        data_dir,
        seed=123,
        class_mode='categorical',
        target_size=(img_height, img_width),
        subset='training',
        batch_size=batch_size,
    )
    val_ds = datagen.flow_from_directory(
        data_dir,
        seed=123,
        class_mode='categorical',
        target_size=(img_height, img_width),
        subset='validation',
        batch_size=batch_size,
    )
    return train_ds, val_ds
2.普通CNN
搭建一个两层卷积的CNN进行训练,代码如下
import os
import pandas as pd
import tensorflow as tf
from util import get_data
from tensorflow.python.keras.callbacks import EarlyStopping
import keras_metrics as km

# NOTE(review): AUTOTUNE and os are defined/imported but never used here.
AUTOTUNE = tf.data.experimental.AUTOTUNE

if __name__ == '__main__':
    train_ds, val_ds = get_data()

    # Baseline CNN: two conv/pool stages followed by a small dense head,
    # ending in a 4-way softmax for the four behaviour classes.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (10, 10), activation='relu',
                               input_shape=(224, 224, 3), padding='same'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy', km.f1_score(), km.recall(),
                           km.precision()])

    # Stop after 40 epochs without val_accuracy improvement and restore
    # the best weights seen so far.
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        verbose=1,
        patience=40,
        restore_best_weights=True,
    )
    # Halve the learning rate when the monitored metric (val_loss by
    # default) plateaus, with a floor of 1e-5.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(min_lr=0.00001,
                                                     factor=0.5)

    # epochs=2000 is an upper bound; early stopping ends training sooner.
    history = model.fit(train_ds, epochs=2000,
                        callbacks=[early_stopping, reduce_lr],
                        validation_data=val_ds)

    # Persist per-epoch metrics and the trained model.
    hist_df = pd.DataFrame(history.history)
    hist_df.to_csv('cnn_history.csv')
    model.save('cnn.h5')
3.VGG16
调用tf的API实现VGG16的搭建训练,使用了ImageNet预训练的权重,代码如下:
import os
import pandas as pd
import tensorflow as tf
from util import get_data
from tensorflow.python.keras.callbacks import EarlyStopping
# import keras_metrics as km
import keras_metrics as km
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.applications.resnet import ResNet152
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.callbacks import EarlyStopping
import keras_metrics as km
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.layers import Dense, Flatten, BatchNormalization, MaxPooling2D
class Metrics(Callback):
    """Callback that logs validation F1/recall/precision each epoch.

    NOTE(review): this class is defined but never passed to ``model.fit``
    in this script, so it currently has no effect. It also reads
    ``self.validation_data``, which Keras stopped populating on callbacks
    in TF2 — if it were registered, ``on_epoch_end`` would fail. Either
    wire it up with explicitly supplied validation data or remove it.
    """

    def on_train_begin(self, logs=None):
        # Per-epoch accumulators (one entry appended per epoch).
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        # Round the softmax outputs to hard 0/1 predictions.
        val_predict = np.asarray(
            self.model.predict(self.validation_data[0])).round()
        val_targ = self.validation_data[1]
        # Micro-averaged F1, plus per-class (average=None) recall and
        # precision arrays.
        _val_f1 = f1_score(val_targ, val_predict, average='micro')
        _val_recall = recall_score(val_targ, val_predict, average=None)
        _val_precision = precision_score(val_targ, val_predict, average=None)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        print("— val_f1: %f " % _val_f1)
# NOTE(review): AUTOTUNE is defined but never used in this script.
AUTOTUNE = tf.data.experimental.AUTOTUNE

if __name__ == '__main__':
    train_ds, val_ds = get_data()

    # VGG16 backbone pretrained on ImageNet, used as a frozen feature
    # extractor (include_top=False drops the original classifier head).
    # Renamed from the misleading "mobile_net" — this is VGG16.
    base_model = VGG16(input_shape=(224, 224, 3), include_top=False)
    base_model.trainable = False

    # Custom classification head on top of the frozen backbone.
    model = Sequential([
        base_model,
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(1000, activation='relu'),
        BatchNormalization(),
        Dense(200, activation='relu'),
        BatchNormalization(),
        Dense(4, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy', km.f1_score(), km.recall(),
                           km.precision()])

    # Stop after 40 epochs without val_accuracy improvement and restore
    # the best weights seen so far.
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        verbose=1,
        patience=40,
        restore_best_weights=True,
    )
    # Cut the learning rate to 20% on plateau (val_loss by default),
    # with a floor of 1e-5.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(min_lr=0.00001,
                                                     factor=0.2)

    # epochs=2000 is an upper bound; early stopping ends training sooner.
    history = model.fit(train_ds, epochs=2000,
                        callbacks=[early_stopping, reduce_lr],
                        validation_data=val_ds)

    # Persist per-epoch metrics and the trained model.
    hist_df = pd.DataFrame(history.history)
    hist_df.to_csv('vgg16_history.csv')
    model.save('vgg16.h5')
训练过程如下
Epoch 1/2000
30/30 [==============================] - 97s 3s/step - loss: 1.2190 - accuracy: 0.6043 - f1_score: 0.5361 - recall: 0.5300 - precision: 0.5456 - val_loss: 7.5022 - val_accuracy: 0.3008 - val_f1_score: 0.6771 - val_recall: 0.7871 - val_precision: 0.5975
Epoch 2/2000
30/30 [==============================] - 94s 3s/step - loss: 0.1590 - accuracy: 0.9471 - f1_score: 0.6751 - recall: 0.8289 - precision: 0.5698 - val_loss: 5.9977 - val_accuracy: 0.3178 - val_f1_score: 0.7130 - val_recall: 0.8715 - val_precision: 0.6038
Epoch 3/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0506 - accuracy: 0.9903 - f1_score: 0.7162 - recall: 0.8887 - precision: 0.6000 - val_loss: 2.7450 - val_accuracy: 0.5297 - val_f1_score: 0.7537 - val_recall: 0.9138 - val_precision: 0.6414
Epoch 4/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0202 - accuracy: 0.9965 - f1_score: 0.7629 - recall: 0.9216 - precision: 0.6510 - val_loss: 1.6227 - val_accuracy: 0.6483 - val_f1_score: 0.7910 - val_recall: 0.9348 - val_precision: 0.6855
Epoch 5/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0247 - accuracy: 0.9989 - f1_score: 0.7976 - recall: 0.9386 - precision: 0.6935 - val_loss: 1.1019 - val_accuracy: 0.7246 - val_f1_score: 0.8181 - val_recall: 0.9466 - val_precision: 0.7204
Epoch 6/2000
30/30 [==============================] - 177s 6s/step - loss: 0.0093 - accuracy: 1.0000 - f1_score: 0.8251 - recall: 0.9497 - precision: 0.7294 - val_loss: 0.8524 - val_accuracy: 0.7881 - val_f1_score: 0.8414 - val_recall: 0.9553 - val_precision: 0.7518
Epoch 7/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0063 - accuracy: 1.0000 - f1_score: 0.8470 - recall: 0.9572 - precision: 0.7597 - val_loss: 0.7584 - val_accuracy: 0.8136 - val_f1_score: 0.8602 - val_recall: 0.9614 - val_precision: 0.7783
Epoch 8/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0035 - accuracy: 1.0000 - f1_score: 0.8652 - recall: 0.9628 - precision: 0.7856 - val_loss: 0.7641 - val_accuracy: 0.7966 - val_f1_score: 0.8751 - val_recall: 0.9658 - val_precision: 0.7999
Epoch 9/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0039 - accuracy: 1.0000 - f1_score: 0.8790 - recall: 0.9668 - precision: 0.8058 - val_loss: 0.8206 - val_accuracy: 0.7881 - val_f1_score: 0.8870 - val_recall: 0.9691 - val_precision: 0.8177
Epoch 10/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0037 - accuracy: 1.0000 - f1_score: 0.8901 - recall: 0.9700 - precision: 0.8223 - val_loss: 0.8147 - val_accuracy: 0.8051 - val_f1_score: 0.8971 - val_recall: 0.9721 - val_precision: 0.8329
Epoch 11/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0028 - accuracy: 1.0000 - f1_score: 0.8996 - recall: 0.9726 - precision: 0.8368 - val_loss: 0.8106 - val_accuracy: 0.8008 - val_f1_score: 0.9054 - val_recall: 0.9741 - val_precision: 0.8458
Epoch 12/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0014 - accuracy: 1.0000 - f1_score: 0.9076 - recall: 0.9748 - precision: 0.8492 - val_loss: 0.8184 - val_accuracy: 0.8051 - val_f1_score: 0.9125 - val_recall: 0.9761 - val_precision: 0.8566
Epoch 13/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0021 - accuracy: 1.0000 - f1_score: 0.9143 - recall: 0.9766 - precision: 0.8595 - val_loss: 0.7995 - val_accuracy: 0.8093 - val_f1_score: 0.9184 - val_recall: 0.9777 - val_precision: 0.8660
Epoch 14/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0015 - accuracy: 1.0000 - f1_score: 0.9201 - recall: 0.9781 - precision: 0.8687 - val_loss: 0.8233 - val_accuracy: 0.8008 - val_f1_score: 0.9237 - val_recall: 0.9790 - val_precision: 0.8744
Epoch 15/2000
30/30 [==============================] - 94s 3s/step - loss: 9.9890e-04 - accuracy: 1.0000 - f1_score: 0.9251 - recall: 0.9794 - precision: 0.8765 - val_loss: 0.8341 - val_accuracy: 0.8093 - val_f1_score: 0.9283 - val_recall: 0.9802 - val_precision: 0.8816
Epoch 16/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0012 - accuracy: 1.0000 - f1_score: 0.9295 - recall: 0.9805 - precision: 0.8835 - val_loss: 0.8482 - val_accuracy: 0.8136 - val_f1_score: 0.9321 - val_recall: 0.9812 - val_precision: 0.8877
Epoch 17/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0018 - accuracy: 1.0000 - f1_score: 0.9333 - recall: 0.9815 - precision: 0.8897 - val_loss: 0.8538 - val_accuracy: 0.8093 - val_f1_score: 0.9356 - val_recall: 0.9819 - val_precision: 0.8935
Epoch 18/2000
30/30 [==============================] - 93s 3s/step - loss: 9.5073e-04 - accuracy: 1.0000 - f1_score: 0.9366 - recall: 0.9822 - precision: 0.8950 - val_loss: 0.8545 - val_accuracy: 0.8051 - val_f1_score: 0.9388 - val_recall: 0.9828 - val_precision: 0.8986
Epoch 19/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0011 - accuracy: 1.0000 - f1_score: 0.9396 - recall: 0.9828 - precision: 0.9001 - val_loss: 0.8551 - val_accuracy: 0.8093 - val_f1_score: 0.9416 - val_recall: 0.9833 - val_precision: 0.9032
Epoch 20/2000
30/30 [==============================] - 93s 3s/step - loss: 0.0021 - accuracy: 1.0000 - f1_score: 0.9423 - recall: 0.9835 - precision: 0.9045 - val_loss: 0.8596 - val_accuracy: 0.8093 - val_f1_score: 0.9441 - val_recall: 0.9840 - val_precision: 0.9073
Epoch 21/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0013 - accuracy: 1.0000 - f1_score: 0.9449 - recall: 0.9842 - precision: 0.9086 - val_loss: 0.8619 - val_accuracy: 0.8093 - val_f1_score: 0.9465 - val_recall: 0.9846 - val_precision: 0.9113
Epoch 22/2000
30/30 [==============================] - 106s 4s/step - loss: 7.7389e-04 - accuracy: 1.0000 - f1_score: 0.9471 - recall: 0.9848 - precision: 0.9122 - val_loss: 0.8651 - val_accuracy: 0.8093 - val_f1_score: 0.9486 - val_recall: 0.9852 - val_precision: 0.9147
Epoch 23/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0010 - accuracy: 1.0000 - f1_score: 0.9492 - recall: 0.9853 - precision: 0.9157 - val_loss: 0.8648 - val_accuracy: 0.8093 - val_f1_score: 0.9506 - val_recall: 0.9856 - val_precision: 0.9179
Epoch 24/2000
30/30 [==============================] - 34716s 1197s/step - loss: 9.2160e-04 - accuracy: 1.0000 - f1_score: 0.9511 - recall: 0.9858 - precision: 0.9188 - val_loss: 0.8697 - val_accuracy: 0.8093 - val_f1_score: 0.9524 - val_recall: 0.9862 - val_precision: 0.9208
Epoch 25/2000
30/30 [==============================] - 95s 3s/step - loss: 0.0022 - accuracy: 1.0000 - f1_score: 0.9528 - recall: 0.9862 - precision: 0.9216 - val_loss: 0.8781 - val_accuracy: 0.8093 - val_f1_score: 0.9540 - val_recall: 0.9866 - val_precision: 0.9235
Epoch 26/2000
30/30 [==============================] - 97s 3s/step - loss: 7.5627e-04 - accuracy: 1.0000 - f1_score: 0.9545 - recall: 0.9867 - precision: 0.9243 - val_loss: 0.8821 - val_accuracy: 0.8093 - val_f1_score: 0.9555 - val_recall: 0.9869 - val_precision: 0.9261
Epoch 27/2000
30/30 [==============================] - 96s 3s/step - loss: 0.0011 - accuracy: 1.0000 - f1_score: 0.9560 - recall: 0.9870 - precision: 0.9268 - val_loss: 0.8838 - val_accuracy: 0.8136 - val_f1_score: 0.9570 - val_recall: 0.9873 - val_precision: 0.9285
Epoch 28/2000
30/30 [==============================] - 96s 3s/step - loss: 0.0012 - accuracy: 1.0000 - f1_score: 0.9574 - recall: 0.9874 - precision: 0.9291 - val_loss: 0.8841 - val_accuracy: 0.8136 - val_f1_score: 0.9583 - val_recall: 0.9876 - val_precision: 0.9307
Epoch 29/2000
30/30 [==============================] - 96s 3s/step - loss: 0.0012 - accuracy: 1.0000 - f1_score: 0.9587 - recall: 0.9877 - precision: 0.9313 - val_loss: 0.8835 - val_accuracy: 0.8136 - val_f1_score: 0.9595 - val_recall: 0.9880 - val_precision: 0.9327
Epoch 30/2000
30/30 [==============================] - 96s 3s/step - loss: 6.5889e-04 - accuracy: 1.0000 - f1_score: 0.9599 - recall: 0.9880 - precision: 0.9333 - val_loss: 0.8835 - val_accuracy: 0.8093 - val_f1_score: 0.9607 - val_recall: 0.9882 - val_precision: 0.9347
Epoch 31/2000
30/30 [==============================] - 95s 3s/step - loss: 0.0012 - accuracy: 1.0000 - f1_score: 0.9610 - recall: 0.9883 - precision: 0.9351 - val_loss: 0.8831 - val_accuracy: 0.8093 - val_f1_score: 0.9618 - val_recall: 0.9885 - val_precision: 0.9364
Epoch 32/2000
30/30 [==============================] - 95s 3s/step - loss: 8.9105e-04 - accuracy: 1.0000 - f1_score: 0.9621 - recall: 0.9886 - precision: 0.9370 - val_loss: 0.8833 - val_accuracy: 0.8093 - val_f1_score: 0.9628 - val_recall: 0.9888 - val_precision: 0.9382
Epoch 33/2000
30/30 [==============================] - 98s 3s/step - loss: 9.7961e-04 - accuracy: 1.0000 - f1_score: 0.9631 - recall: 0.9888 - precision: 0.9386 - val_loss: 0.8836 - val_accuracy: 0.8093 - val_f1_score: 0.9638 - val_recall: 0.9890 - val_precision: 0.9397
Epoch 34/2000
30/30 [==============================] - 97s 3s/step - loss: 6.3762e-04 - accuracy: 1.0000 - f1_score: 0.9640 - recall: 0.9891 - precision: 0.9402 - val_loss: 0.8847 - val_accuracy: 0.8093 - val_f1_score: 0.9646 - val_recall: 0.9892 - val_precision: 0.9412
Epoch 35/2000
30/30 [==============================] - 96s 3s/step - loss: 0.0014 - accuracy: 1.0000 - f1_score: 0.9649 - recall: 0.9893 - precision: 0.9417 - val_loss: 0.8852 - val_accuracy: 0.8093 - val_f1_score: 0.9655 - val_recall: 0.9895 - val_precision: 0.9426
Epoch 36/2000
30/30 [==============================] - 96s 3s/step - loss: 8.7828e-04 - accuracy: 1.0000 - f1_score: 0.9657 - recall: 0.9895 - precision: 0.9431 - val_loss: 0.8859 - val_accuracy: 0.8093 - val_f1_score: 0.9663 - val_recall: 0.9897 - val_precision: 0.9440
Epoch 37/2000
30/30 [==============================] - 96s 3s/step - loss: 9.5012e-04 - accuracy: 1.0000 - f1_score: 0.9665 - recall: 0.9897 - precision: 0.9444 - val_loss: 0.8865 - val_accuracy: 0.8093 - val_f1_score: 0.9671 - val_recall: 0.9898 - val_precision: 0.9453
Epoch 38/2000
30/30 [==============================] - 97s 3s/step - loss: 5.0426e-04 - accuracy: 1.0000 - f1_score: 0.9673 - recall: 0.9899 - precision: 0.9457 - val_loss: 0.8863 - val_accuracy: 0.8093 - val_f1_score: 0.9677 - val_recall: 0.9900 - val_precision: 0.9465
Epoch 39/2000
30/30 [==============================] - 97s 3s/step - loss: 6.8696e-04 - accuracy: 1.0000 - f1_score: 0.9680 - recall: 0.9900 - precision: 0.9469 - val_loss: 0.8868 - val_accuracy: 0.8093 - val_f1_score: 0.9684 - val_recall: 0.9901 - val_precision: 0.9477
Epoch 40/2000
30/30 [==============================] - 95s 3s/step - loss: 0.0014 - accuracy: 0.9998 - f1_score: 0.9686 - recall: 0.9902 - precision: 0.9480 - val_loss: 0.8868 - val_accuracy: 0.8093 - val_f1_score: 0.9690 - val_recall: 0.9904 - val_precision: 0.9486
Epoch 41/2000
30/30 [==============================] - 100s 3s/step - loss: 0.0105 - accuracy: 0.9949 - f1_score: 0.9691 - recall: 0.9904 - precision: 0.9488 - val_loss: 0.8894 - val_accuracy: 0.8136 - val_f1_score: 0.9696 - val_recall: 0.9905 - val_precision: 0.9495
Epoch 42/2000
30/30 [==============================] - 96s 3s/step - loss: 0.0011 - accuracy: 1.0000 - f1_score: 0.9697 - recall: 0.9905 - precision: 0.9498 - val_loss: 0.8899 - val_accuracy: 0.8136 - val_f1_score: 0.9702 - val_recall: 0.9906 - val_precision: 0.9505
Epoch 43/2000
30/30 [==============================] - 95s 3s/step - loss: 6.8680e-04 - accuracy: 1.0000 - f1_score: 0.9703 - recall: 0.9907 - precision: 0.9508 - val_loss: 0.8896 - val_accuracy: 0.8093 - val_f1_score: 0.9707 - val_recall: 0.9908 - val_precision: 0.9515
Epoch 44/2000
30/30 [==============================] - 95s 3s/step - loss: 8.2077e-04 - accuracy: 1.0000 - f1_score: 0.9709 - recall: 0.9908 - precision: 0.9517 - val_loss: 0.8897 - val_accuracy: 0.8093 - val_f1_score: 0.9713 - val_recall: 0.9909 - val_precision: 0.9524
Epoch 45/2000
30/30 [==============================] - 95s 3s/step - loss: 6.7027e-04 - accuracy: 1.0000 - f1_score: 0.9714 - recall: 0.9909 - precision: 0.9526 - val_loss: 0.8903 - val_accuracy: 0.8093 - val_f1_score: 0.9718 - val_recall: 0.9910 - val_precision: 0.9533
Epoch 46/2000
30/30 [==============================] - 97s 3s/step - loss: 5.5451e-04 - accuracy: 1.0000 - f1_score: 0.9719 - recall: 0.9911 - precision: 0.9535 - val_loss: 0.8895 - val_accuracy: 0.8093 - val_f1_score: 0.9723 - val_recall: 0.9912 - val_precision: 0.9541
Epoch 47/2000
30/30 [==============================] - 94s 3s/step - loss: 0.0011 - accuracy: 1.0000 - f1_score: 0.9724 - recall: 0.9912 - precision: 0.9544 - val_loss: 0.8893 - val_accuracy: 0.8093 - val_f1_score: 0.9728 - val_recall: 0.9913 - val_precision: 0.9549
4.ResNet152
调用tf的API实现ResNet152的搭建训练,同样使用了ImageNet预训练的权重,代码如下:
import os
import pandas as pd
import tensorflow as tf
from util import get_data
from tensorflow.python.keras.callbacks import EarlyStopping
# import keras_metrics as km
import keras_metrics as km
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.applications.resnet import ResNet152
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.callbacks import EarlyStopping
import keras_metrics as km
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.layers import Dense, Flatten, BatchNormalization, MaxPooling2D
class Metrics(Callback):
    """Callback that logs validation F1/recall/precision each epoch.

    NOTE(review): this class is defined but never passed to ``model.fit``
    in this script, so it currently has no effect. It also reads
    ``self.validation_data``, which Keras stopped populating on callbacks
    in TF2 — if it were registered, ``on_epoch_end`` would fail. Either
    wire it up with explicitly supplied validation data or remove it.
    """

    def on_train_begin(self, logs=None):
        # Per-epoch accumulators (one entry appended per epoch).
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        # Round the softmax outputs to hard 0/1 predictions.
        val_predict = np.asarray(
            self.model.predict(self.validation_data[0])).round()
        val_targ = self.validation_data[1]
        # Micro-averaged F1, plus per-class (average=None) recall and
        # precision arrays.
        _val_f1 = f1_score(val_targ, val_predict, average='micro')
        _val_recall = recall_score(val_targ, val_predict, average=None)
        _val_precision = precision_score(val_targ, val_predict, average=None)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        print("— val_f1: %f " % _val_f1)
# NOTE(review): AUTOTUNE is defined but never used in this script.
AUTOTUNE = tf.data.experimental.AUTOTUNE

if __name__ == '__main__':
    train_ds, val_ds = get_data()

    # ResNet152 backbone pretrained on ImageNet, used as a frozen feature
    # extractor (include_top=False drops the original classifier head).
    # Renamed from the misleading "mobile_net" — this is ResNet152.
    base_model = ResNet152(input_shape=(224, 224, 3), include_top=False)
    base_model.trainable = False

    # Custom classification head on top of the frozen backbone.
    model = Sequential([
        base_model,
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(1000, activation='relu'),
        BatchNormalization(),
        Dense(200, activation='relu'),
        BatchNormalization(),
        Dense(4, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy', km.f1_score(), km.recall(),
                           km.precision()])

    # Stop after 80 epochs without val_accuracy improvement and restore
    # the best weights seen so far.
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        verbose=1,
        patience=80,
        restore_best_weights=True,
    )
    # Cut the learning rate to 20% on plateau (val_loss by default),
    # with a floor of 1e-5.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(min_lr=0.00001,
                                                     factor=0.2)

    # epochs=2000 is an upper bound; early stopping ends training sooner.
    history = model.fit(train_ds, epochs=2000,
                        callbacks=[early_stopping, reduce_lr],
                        validation_data=val_ds)

    # Persist per-epoch metrics and the trained model.
    hist_df = pd.DataFrame(history.history)
    hist_df.to_csv('resNet_history.csv')
    model.save('resNet.h5')