自构建CNN 猫狗训练模型

HW 发布于 1 天前 10 次阅读 预计阅读时间: 2 分钟


被拉去参加比赛，初赛就被拉爆了，代码留个档。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import (
    Conv2D, MaxPool2D, Flatten, Dense, Dropout, Input, BatchNormalization
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
import sys

# Fix Chinese text output: force UTF-8 on stdout so the Chinese print()
# messages in this script don't fail on consoles using a legacy code page.
# Some environments (e.g. Jupyter / IDE consoles) replace sys.stdout with an
# object that has no reconfigure() method, so only call it when it exists.
if hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(encoding='utf-8')

# Use a CJK-capable font (SimHei) for Chinese plot titles/labels, and draw the
# minus sign as ASCII '-' (the CJK font may not provide the Unicode minus glyph).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# --- Data configuration --------------------------------------------------
# flow_from_directory treats each sub-directory of train_dir as one class
# (class_mode='binary' below implies exactly two, e.g. cats/ and dogs/).
train_dir = r'C:\Users\DELL\Desktop\DogsVsCats\data\train'
img_size = (224, 224)
batch_size = 16
epochs = 30
# Fraction of the images held out for validation. Both generators below must
# use the same value so the 'training' and 'validation' subsets of the same
# directory are complementary.
validation_split = 0.2

# Training pipeline: scale pixels to [0, 1] and apply random augmentation.
datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=validation_split
)

# Validation pipeline: rescale only. Augmenting validation images would make
# val_loss / val_accuracy noisy and biased, which in turn misleads the
# callbacks that monitor val_loss (early stopping, LR reduction).
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_split
)

train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='training',
    shuffle=True,
    seed=42
)

# shuffle=False keeps validation order fixed across epochs.
val_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='validation',
    shuffle=False,
    seed=42
)

# Label mapping chosen by flow_from_directory (class folder name -> 0/1).
class_indices = train_generator.class_indices
print("类别映射:", class_indices)

# Feature extractor: three Conv -> BatchNorm -> MaxPool stages with doubling
# filter counts; each 2x2 pooling halves the spatial size
# (224 -> 112 -> 56 -> 28 for the 224x224 input).
model = Sequential()
model.add(Input(shape=(img_size[0], img_size[1], 3)))
for n_filters in (16, 32, 64):
    model.add(Conv2D(n_filters, 3, padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPool2D(2))

# Fully connected classification head. The single sigmoid unit is trained
# against the 0/1 labels, i.e. it scores the class mapped to 1 above.
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))  # regularization against overfitting
model.add(Dense(1, activation='sigmoid'))

adam = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['accuracy'])

model.summary()

# Stop once val_loss has not improved for 10 epochs and roll the model back
# to the weights of its best epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Multiply the learning rate by 0.8 after 4 epochs without val_loss
# improvement, never going below 1e-6. The shorter patience lets the LR be
# reduced before early stopping gives up.
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=4, min_lr=1e-6)

training_callbacks = [early_stop, lr_scheduler]
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=training_callbacks,
)

# Persist the trained model; the .h5 extension selects Keras' HDF5 format.
model.save(r'C:\Users\DELL\Desktop\DogsVsCats\custom_cnn_cat_dog_model_binary.h5')
print("模型已保存")

# history.history holds one value per completed epoch. Plot against an
# explicit 1-based epoch axis; matplotlib would otherwise number from 0.
epoch_axis = range(1, len(history.history['loss']) + 1)

plt.figure(figsize=(12, 4))

# Left panel: training vs. validation accuracy.
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, history.history['accuracy'], label='训练准确率')
plt.plot(epoch_axis, history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.xlabel('轮数')
plt.ylabel('准确率')
plt.legend()

# Right panel: training vs. validation loss.
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, history.history['loss'], label='训练损失')
plt.plot(epoch_axis, history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.xlabel('轮数')
plt.ylabel('损失')
plt.legend()

# Prevent the two panels' labels from overlapping in the saved image.
plt.tight_layout()
plt.savefig(r'C:\Users\DELL\Desktop\DogsVsCats\custom_cnn_train_history_binary.png')
plt.show()
此作者没有提供个人介绍。
最后更新于 2026-03-04