
Semantic Segmentation Series 6: Unet++ (PyTorch Implementation)

2023-02-28


Contents

The Unet++ Network

Dense connection

Deep supervision

Model Implementation

Unet++

Dataset Preparation

Model Training

Training Results


Paper: "UNet++: A Nested U-Net Architecture for Medical Image Segmentation" (Zhou et al., 2018)

For the original authors' own take on Unet and Unet++, see: 研习U-Net

This post continues the series from: Semantic Segmentation Series 2: Unet (PyTorch Implementation).

This article introduces the Unet++ network, reproduces it in the PyTorch framework, and trains it on the CamVid dataset.


The Unet++ Network

Dense connection

Unet++ keeps the overall encoder-decoder structure of Unet while borrowing the dense connection pattern of DenseNet (the extra branches in Figure 1).

Figure 1: Unet++ network architecture

Through dense connections between the layers, the blocks are wired together much as in DenseNet: every block feeds the blocks after it and can draw on the blocks before it, so features are reused throughout the network and segmentation quality improves accordingly.

In practice, each successive downsampling inevitably discards fine detail. Unet recovers some of it through skip connections, but can we do better? Unet++ answers yes: with dense connections, every level keeps as much local detail and global context as possible, the levels exchange information across these bridges, and everything is finally shared with the last layer, preserving and reconstructing both global and local information.
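Concretely, in (roughly) the notation of the UNet++ paper, with $H$ a convolution block, $D$ downsampling, $U$ upsampling, and $[\cdot]$ channel-wise concatenation, the node $x^{i,j}$ at depth $i$, column $j$ is

$$
x^{i,j} =
\begin{cases}
H\big(D(x^{i-1,0})\big), & j = 0\\[4pt]
H\big(\big[x^{i,0}, \dots, x^{i,j-1},\; U(x^{i+1,j-1})\big]\big), & j > 0
\end{cases}
$$

For example, $x^{1,2}$ concatenates $x^{1,0}$, $x^{1,1}$, and the upsampled $x^{2,1}$, which is why CONV1_2 in the code below takes 128*3 input channels.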

Deep supervision

Simply wiring the blocks together this way already works well. But we can also observe that a Unet++ is in effect many Unets of different depths stacked on top of one another. Could each of these sub-networks then produce its own output and its own loss? Naturally, it can.

The authors therefore propose deep supervision: supervise the output of the sub-network at every depth and combine the losses in some way (for example, a weighted sum). The result is a single weighted loss over the depth-1, -2, -3, and -4 sub-Unets (Figure 2).
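As a minimal sketch of such a weighted loss (my illustration, not code from the original post), assuming the model below is run with deep_supervision=True and that uniform weights are an acceptable starting point:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def deep_supervision_loss(outputs, target, weights=(0.25, 0.25, 0.25, 0.25)):
    # outputs: the list [out_put1, ..., out_put4] returned by the model below
    # target:  [N, H, W] tensor of integer class ids
    return sum(w * criterion(out, target) for w, out in zip(weights, outputs))

The weights are a free design choice; equal weighting is the simplest option, and schemes that emphasize the deepest output are also common.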

Figure 2: Fusing Unet++ sub-networks of different depths

So what does deep supervision buy us? Pruning.

Since Unet++ is a stack of Unets of different depths, the deeper portion can be removed without changing the forward pass that produces the shallower outputs. So if you find that the third output performs about as well as the fourth, you can prune the depth-4 sub-network without hesitation, e.g. by cutting away the brown part in Figure 3. The result is a noticeably lighter network.
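To make the pruning idea concrete, here is a minimal sketch (my illustration, not from the original post) built on the implementation below: with deep_supervision=True, the four heads correspond to the sub-networks of depth 1 through 4, so evaluating a "depth-3" pruned model amounts to reading the third head. (This alone does not save compute; an actually pruned network would also drop x_4_0, x_3_1, x_2_2, and x_1_3 from the forward pass.)

import torch

model = UnetPlusPlus(num_classes=3, deep_supervision=True).eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 224, 224))
# outputs[0..3] come from depths 1..4; a depth-3 pruned model keeps outputs[2]
pred_depth3 = outputs[2].argmax(dim=1)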

Figure 3: The pruned model

Model Implementation

Unet++

To keep things intuitive, every symbol in the code below matches the corresponding node in the network diagram.

import torch
import torch.nn as nn

class ContinusParalleConv(nn.Module):
    # A block of two consecutive 3x3 convolutions; BatchNorm + ReLU are placed
    # either before each convolution (pre_Batch_Norm=True) or after it (False)
    def __init__(self, in_channels, out_channels, pre_Batch_Norm=True):
        super(ContinusParalleConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if pre_Batch_Norm:
            self.Conv_forward = nn.Sequential(
                nn.BatchNorm2d(self.in_channels),
                nn.ReLU(),
                nn.Conv2d(self.in_channels, self.out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1))
        else:
            self.Conv_forward = nn.Sequential(
                nn.Conv2d(self.in_channels, self.out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1),
                nn.BatchNorm2d(self.out_channels),
                nn.ReLU())

    def forward(self, x):
        x = self.Conv_forward(x)
        return x

class UnetPlusPlus(nn.Module):
    def __init__(self, num_classes, deep_supervision=False):
        super(UnetPlusPlus, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision
        self.filters = [64, 128, 256, 512, 1024]

        # Decoder nodes X^{i,j}: input channels = concatenated same-level
        # features plus one upsampled feature from the level below
        self.CONV3_1 = ContinusParalleConv(512*2, 512, pre_Batch_Norm=True)

        self.CONV2_2 = ContinusParalleConv(256*3, 256, pre_Batch_Norm=True)
        self.CONV2_1 = ContinusParalleConv(256*2, 256, pre_Batch_Norm=True)

        self.CONV1_1 = ContinusParalleConv(128*2, 128, pre_Batch_Norm=True)
        self.CONV1_2 = ContinusParalleConv(128*3, 128, pre_Batch_Norm=True)
        self.CONV1_3 = ContinusParalleConv(128*4, 128, pre_Batch_Norm=True)

        self.CONV0_1 = ContinusParalleConv(64*2, 64, pre_Batch_Norm=True)
        self.CONV0_2 = ContinusParalleConv(64*3, 64, pre_Batch_Norm=True)
        self.CONV0_3 = ContinusParalleConv(64*4, 64, pre_Batch_Norm=True)
        self.CONV0_4 = ContinusParalleConv(64*5, 64, pre_Batch_Norm=True)

        # Encoder stages X^{i,0}
        self.stage_0 = ContinusParalleConv(3, 64, pre_Batch_Norm=False)
        self.stage_1 = ContinusParalleConv(64, 128, pre_Batch_Norm=False)
        self.stage_2 = ContinusParalleConv(128, 256, pre_Batch_Norm=False)
        self.stage_3 = ContinusParalleConv(256, 512, pre_Batch_Norm=False)
        self.stage_4 = ContinusParalleConv(512, 1024, pre_Batch_Norm=False)

        self.pool = nn.MaxPool2d(2)

        # 2x upsampling between adjacent levels
        self.upsample_3_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.upsample_2_1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample_2_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample_1_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.upsample_1_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.upsample_1_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.upsample_0_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.upsample_0_2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.upsample_0_3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.upsample_0_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)

        # Segmentation heads, one per deep-supervision output
        self.final_super_0_1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self.num_classes, 3, padding=1),
        )
        self.final_super_0_2 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self.num_classes, 3, padding=1),
        )
        self.final_super_0_3 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self.num_classes, 3, padding=1),
        )
        self.final_super_0_4 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self.num_classes, 3, padding=1),
        )

    def forward(self, x):
        # Encoder column X^{i,0}
        x_0_0 = self.stage_0(x)
        x_1_0 = self.stage_1(self.pool(x_0_0))
        x_2_0 = self.stage_2(self.pool(x_1_0))
        x_3_0 = self.stage_3(self.pool(x_2_0))
        x_4_0 = self.stage_4(self.pool(x_3_0))

        # Nested decoder nodes X^{i,j}, each concatenating all earlier
        # same-level nodes with the upsampled node from one level deeper
        x_0_1 = torch.cat([self.upsample_0_1(x_1_0), x_0_0], 1)
        x_0_1 = self.CONV0_1(x_0_1)

        x_1_1 = torch.cat([self.upsample_1_1(x_2_0), x_1_0], 1)
        x_1_1 = self.CONV1_1(x_1_1)

        x_2_1 = torch.cat([self.upsample_2_1(x_3_0), x_2_0], 1)
        x_2_1 = self.CONV2_1(x_2_1)

        x_3_1 = torch.cat([self.upsample_3_1(x_4_0), x_3_0], 1)
        x_3_1 = self.CONV3_1(x_3_1)

        x_2_2 = torch.cat([self.upsample_2_2(x_3_1), x_2_0, x_2_1], 1)
        x_2_2 = self.CONV2_2(x_2_2)

        x_1_2 = torch.cat([self.upsample_1_2(x_2_1), x_1_0, x_1_1], 1)
        x_1_2 = self.CONV1_2(x_1_2)

        x_1_3 = torch.cat([self.upsample_1_3(x_2_2), x_1_0, x_1_1, x_1_2], 1)
        x_1_3 = self.CONV1_3(x_1_3)

        x_0_2 = torch.cat([self.upsample_0_2(x_1_1), x_0_0, x_0_1], 1)
        x_0_2 = self.CONV0_2(x_0_2)

        x_0_3 = torch.cat([self.upsample_0_3(x_1_2), x_0_0, x_0_1, x_0_2], 1)
        x_0_3 = self.CONV0_3(x_0_3)

        x_0_4 = torch.cat([self.upsample_0_4(x_1_3), x_0_0, x_0_1, x_0_2, x_0_3], 1)
        x_0_4 = self.CONV0_4(x_0_4)

        if self.deep_supervision:
            out_put1 = self.final_super_0_1(x_0_1)
            out_put2 = self.final_super_0_2(x_0_2)
            out_put3 = self.final_super_0_3(x_0_3)
            out_put4 = self.final_super_0_4(x_0_4)
            return [out_put1, out_put2, out_put3, out_put4]
        else:
            return self.final_super_0_4(x_0_4)

if __name__ == "__main__":
    print("deep_supervision: False")
    deep_supervision = False
    device = torch.device('cpu')
    inputs = torch.randn((1, 3, 224, 224)).to(device)
    model = UnetPlusPlus(num_classes=3, deep_supervision=deep_supervision).to(device)
    outputs = model(inputs)
    print(outputs.shape)

    print("deep_supervision: True")
    deep_supervision = True
    model = UnetPlusPlus(num_classes=3, deep_supervision=deep_supervision).to(device)
    outputs = model(inputs)
    for out in outputs:
        print(out.shape)

The test output is as follows.
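(For reference, with the 1x3x224x224 input above, single-output mode should print torch.Size([1, 3, 224, 224]), and deep-supervision mode should print that same shape four times: each MaxPool2d halving is exactly undone by the matching ConvTranspose2d, so every head produces a full-resolution map.)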

Dataset Preparation

We train on the CamVid dataset; see the earlier post CamVid数据集的创建和使用-pytorch for how to build and organize it.

# Imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

torch.manual_seed(17)

# Custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Reads images, applies augmentation and preprocessing.

    Args:
        images_dir (str): path to the images folder
        masks_dir (str): path to the segmentation masks folder
    """
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ])
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    def __getitem__(self, i):
        # Read data; the mask is stored as RGB, so keep one channel of class ids
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array(Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image, mask=mask)
        return image['image'], image['mask'][:, :, 0]

    def __len__(self):
        return len(self.ids)

# Dataset paths -- adjust to your own layout
DATA_DIR = r'dataset\camvid'
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')

train_dataset = CamVidDataset(
    x_train_dir,
    y_train_dir,
)
val_dataset = CamVidDataset(
    x_valid_dir,
    y_valid_dir,
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True, drop_last=True)
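As a quick sanity check (my addition, not in the original post), you can pull one batch and confirm the shapes the training loop expects:

images, masks = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([8, 3, 224, 224])
print(masks.shape)   # expected: torch.Size([8, 224, 224]), integer class ids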

Model Training

model = UnetPlusPlus(num_classes=33).cuda()
# To resume from a saved checkpoint:
# model.load_state_dict(torch.load(r"checkpoints/Unet++_25.pth"), strict=False)
from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd

# Loss: multi-class cross entropy, ignoring the 255 "void" label
lossf = nn.CrossEntropyLoss(ignore_index=255)
# Optimizer: SGD (the original comment said Adam, but the code uses SGD)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# StepLR with step_size=50 leaves the LR unchanged within this 50-epoch run
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)
# Train for 50 epochs
epochs_num = 50

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")

        # -------- Save training statistics --------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch)
        time_list.append(timer.sum())
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/Unet++_camvid1.xlsx")

        # -------- Save a checkpoint every 5 epochs --------
        if np.mod(epoch + 1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/Unet++_{epoch+1}.pth')

Start training:

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, scheduler)

Training Results

Unet++ training results:
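To actually look at predictions, here is a minimal inference sketch (my addition; it assumes the 50-epoch run above completed and wrote checkpoints/Unet++_50.pth):

import torch

model = UnetPlusPlus(num_classes=33).cuda()
model.load_state_dict(torch.load("checkpoints/Unet++_50.pth"))
model.eval()
with torch.no_grad():
    image, mask = val_dataset[0]                    # one normalized image and its mask
    logits = model(image.unsqueeze(0).cuda())       # [1, 33, 224, 224]
    pred = logits.argmax(dim=1).squeeze(0).cpu()    # [224, 224] predicted class ids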