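# 被拉去参加比赛，初赛就被拉爆了，代码留个档。
# (Got pulled into a competition and was knocked out in the first round; archiving the code here.)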
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import (
Conv2D, MaxPool2D, Flatten, Dense, Dropout, Input, BatchNormalization
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
import sys
# Make Chinese text display correctly, both in console output and in plots.
# io.TextIOWrapper.reconfigure exists only on real text-file streams; some hosts
# replace sys.stdout with an object that lacks it (e.g. Jupyter's OutStream, or
# None under pythonw), so only call it when it is actually available.
_reconfigure = getattr(sys.stdout, 'reconfigure', None)
if _reconfigure is not None:
    _reconfigure(encoding='utf-8')
# Matplotlib tries these families in order. SimHei (the original choice) comes
# first; the others are fallbacks for machines where SimHei is not installed.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC', 'DejaVu Sans']
# Render minus signs as ASCII '-' instead of U+2212, which CJK fonts often lack.
plt.rcParams['axes.unicode_minus'] = False
# --- Data -------------------------------------------------------------------
# train_dir must use the flow_from_directory layout: one sub-directory per class.
# class_mode='binary' yields a single 0/1 label per image, so it should contain
# exactly two class folders.
train_dir = r'C:\Users\DELL\Desktop\DogsVsCats\data\train'
img_size = (224, 224)
batch_size = 16
epochs = 30
# Fraction of each class's images held out for validation. Both generators below
# must use the same value so that they agree on which files form each subset.
validation_split = 0.2

# Training images: rescaled to [0, 1] and randomly augmented.
datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=validation_split,
)
# Validation images: rescaled only. Do not reuse the augmenting `datagen` here.
# Random rotations, shifts and flips on the held-out set make val_loss and
# val_accuracy noisy, and those metrics drive EarlyStopping and
# ReduceLROnPlateau. Keras picks the training/validation subsets from each
# class's sorted file list by `validation_split` (not randomly), so this
# generator sees exactly the files that the training generator excludes.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_split,
)

train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='training',
    shuffle=True,
    seed=42,
)
val_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='validation',
    shuffle=False,  # fixed order keeps per-epoch validation comparable
    seed=42,
)

# Map from class-folder name to the integer label (0/1) it was assigned.
class_indices = train_generator.class_indices
print("类别映射:", class_indices)
def _build_model():
    """Return the uncompiled binary-classification CNN.

    The network has three conv stages with 16, 32 and 64 filters. Each stage
    is Conv2D(3x3, 'same', ReLU) -> BatchNormalization -> 2x2 max-pool. The
    head is Flatten -> Dense(512, ReLU) -> Dropout(0.5) -> one sigmoid unit,
    whose output is the predicted probability of label 1.
    """
    stack = [Input(shape=(img_size[0], img_size[1], 3))]
    for n_filters in (16, 32, 64):
        stack.extend([
            Conv2D(n_filters, 3, padding='same', activation='relu'),
            BatchNormalization(),
            MaxPool2D(2),
        ])
    # Fully-connected classification head.
    stack.extend([
        Flatten(),
        Dense(512, activation='relu'),
        Dropout(0.5),  # regularization against overfitting
        Dense(1, activation='sigmoid'),
    ])
    return Sequential(stack)


model = _build_model()
# A single sigmoid output pairs with binary cross-entropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()
# Stop after 10 epochs without a val_loss improvement and restore the best weights.
# NOTE(review): in some older tf.keras versions the restore happens only when
# training actually stops early; if all epochs run, the final weights are kept.
# Verify against the installed version.
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# Multiply the learning rate by 0.8 after 4 stagnant val_loss epochs, floor at 1e-6.
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=4, min_lr=1e-6)

history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    callbacks=[early_stop, lr_scheduler],
)

# NOTE: '.h5' selects the legacy HDF5 format; newer Keras recommends '.keras'.
model_path = r'C:\Users\DELL\Desktop\DogsVsCats\custom_cnn_cat_dog_model_binary.h5'
model.save(model_path)
print("模型已保存")
# Plot the training curves (accuracy on the left, loss on the right), save
# them as a PNG, then show them.
# history.history holds one value per completed epoch, so an EarlyStopping
# cut-off simply shortens the curves. The x-axis is numbered from 1 to match
# the "Epoch n/N" counter Keras prints; it was previously 0-based.
epoch_axis = range(1, len(history.history['loss']) + 1)
plt.figure(figsize=(12, 4))
# (history key, train label, validation label, title, y-axis label)
panels = [
    ('accuracy', '训练准确率', '验证准确率', '模型准确率', '准确率'),
    ('loss', '训练损失', '验证损失', '模型损失', '损失'),
]
for position, (key, train_label, val_label, title, ylabel) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epoch_axis, history.history[key], label=train_label)
    plt.plot(epoch_axis, history.history['val_' + key], label=val_label)
    plt.title(title)
    plt.xlabel('轮数')
    plt.ylabel(ylabel)
    plt.legend()
# Stop the two panels' titles and axis labels from overlapping in the saved image.
plt.tight_layout()
plt.savefig(r'C:\Users\DELL\Desktop\DogsVsCats\custom_cnn_train_history_binary.png')
plt.show()

# Comments: none.