Jupyter AI

12 U-Net案例分析

📅 发表日期: 2024年8月12日

分类: 🤖AI 30 个神经网络

👁️阅读: --

在上一篇文章中,我们深入解析了U-Net的结构,探讨了其编码器和解码器的设计,以及如何通过跳跃连接保持高分辨率特征。现在,我们将通过一个实例来展示如何应用U-Net进行图像分割任务,特别是在医学图像处理中,例如肝脏肿瘤的自动分割。

数据集介绍

我们将使用著名的“肝脏肿瘤分割数据集”进行案例分析。该数据集包含了医学影像(如CT扫描),并提供了相应的标注,标注中将肝脏及其肿瘤部分标出。这是一个经典的二分类问题,其中我们需要分割出肝脏区域以及肝脏内的肿瘤。

实现步骤

接下来,我们将从数据预处理开始,再到模型的构建、训练和评估。

1. 数据预处理

首先,我们需要加载数据并进行预处理。确保图像大小一致,通常我们将其调整为128x128256x256。此外,进行数据增强可以提高模型的泛化能力。

import numpy as np
import cv2
from sklearn.model_selection import train_test_split

def load_data(images_path, masks_path):
    images = []
    masks = []
    for img_name in os.listdir(images_path):
        img = cv2.imread(os.path.join(images_path, img_name))
        mask = cv2.imread(os.path.join(masks_path, img_name), 0)  # 读取为灰度图
        img_resized = cv2.resize(img, (256, 256))  # 调整大小
        mask_resized = cv2.resize(mask, (256, 256))
        
        images.append(img_resized)
        masks.append(mask_resized)

    images = np.array(images) / 255.0  # 归一化
    masks = np.array(masks) / 255.0  # 归一化

    return train_test_split(images, masks, test_size=0.2, random_state=42)

X_train, X_val, y_train, y_val = load_data('path_to_images', 'path_to_masks')

2. U-Net模型构建

模型构建将使用Keras库。以下是U-Net模型的简单实现:

from tensorflow.keras import layers, models

def unet_model(input_size=(256, 256, 3)):
    inputs = layers.Input(input_size)

    # Encoder
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    # Bottom
    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)

    # Bottleneck
    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c4)

    # Decoder
    u5 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c4)
    u5 = layers.concatenate([u5, c3])
    c5 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(u5)
    c5 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c5)

    u6 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c2])
    c6 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u6)
    c6 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c6)

    u7 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c1])
    c7 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u7)
    c7 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c7)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(c7)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

model = unet_model()
model.summary()

3. 模型训练

我们可以使用fit函数训练模型,并设置适当的批量大小和训练轮数。使用EarlyStopping可以防止过拟合。

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(patience=5, restore_best_weights=True)

history = model.fit(X_train, y_train, 
                    validation_data=(X_val, y_val),
                    epochs=50, 
                    batch_size=16, 
                    callbacks=[early_stopping])

4. 模型评估与结果可视化

在模型训练完成后,我们需要评估其性能。我们可以使用一些常用的指标,如IoU(交并比)Dice Coefficient。以下是示例代码用于绘制训练过程中的损失和准确率曲线:

import matplotlib.pyplot as plt

# 绘制损失曲线
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()

# 绘制准确率曲线
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.plot(history.history['val_accuracy'],

🤖AI 30 个神经网络 (滚动鼠标查看)