第一步:CycleGAN介绍

        CycleGAN主要用于图像之间的风格转换。假设有两个图像域X和Y，两者中的图像不需要一一配对，算法通过训练同时学习两个方向的映射（X→Y 和 Y→X），并用“循环一致性”约束保证图像转换到另一个域后还能被转换回原图。训练时不需要成对的样本，只需要源域和目标域的图像集合；训练完成后，网络即可把源域图像迁移到目标域。CycleGAN适用于非配对的图像到图像转换，解决了模型必须依赖成对数据进行训练的难题。

第二步:CycleGAN网络结构

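第三步:模型代码展示

以下仅为生成器 ResnetGenerator 的代码；其中用到的 HourGlass、ResnetBlock、ResnetSoftAdaLINBlock 和 LIN 模块在完整工程中定义，此处未列出。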

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class ResnetGenerator(nn.Module):
    """Image-to-image generator with hourglass stages, CAM attention and a soft-AdaLIN decoder.

    Pipeline (see ``forward``): 7x7 stem conv -> two hourglass stages ->
    two stride-2 down-sampling blocks -> four residual encoder blocks ->
    class-activation-map (CAM) re-weighting -> four soft-AdaLIN decoder
    blocks -> two 2x up-sampling blocks -> two hourglass stages ->
    7x7 conv + Tanh.

    ``HourGlass``, ``ResnetBlock``, ``ResnetSoftAdaLINBlock`` and ``LIN`` are
    defined elsewhere in the project.

    Args:
        ngf: Base channel width; the bottleneck runs at ``ngf * 4`` channels.
        img_size: Side length of the (assumed square) input images. Only used
            to size the first ``FC`` layer when ``light`` is False.
        light: If True, global-average-pool the bottleneck before the
            gamma/beta ``FC`` MLP, so its input size is ``ngf * 4`` regardless
            of resolution; otherwise the whole bottleneck map is flattened
            into it.
    """

    def __init__(self, ngf=64, img_size=256, light=False):
        super(ResnetGenerator, self).__init__()
        self.light = light

        # NOTE: submodules are created in the same order and under the same
        # attribute names as before so state_dict keys (and init order under a
        # fixed seed) stay unchanged.
        self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf),
                                        nn.ReLU(True))

        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-Sampling: each block maps a side of length H to ceil(H / 2).
        self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf * 2),
                                        nn.ReLU(True))

        self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf * 4),
                                        nn.ReLU(True))

        # Encoder Bottleneck
        self.EncodeBlock1 = ResnetBlock(ngf * 4)
        self.EncodeBlock2 = ResnetBlock(ngf * 4)
        self.EncodeBlock3 = ResnetBlock(ngf * 4)
        self.EncodeBlock4 = ResnetBlock(ngf * 4)

        # Class Activation Map: one linear scorer per pooling type; their
        # weights also re-weight the bottleneck channels in forward().
        self.gap_fc = nn.Linear(ngf * 4, 1)
        self.gmp_fc = nn.Linear(ngf * 4, 1)
        # Fuses the concatenated GAP- and GMP-weighted maps back to ngf*4 channels.
        self.conv1x1 = nn.Conv2d(ngf * 8, ngf * 4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            self.FC = nn.Sequential(nn.Linear(ngf * 4, ngf * 4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf * 4, ngf * 4),
                                    nn.ReLU(True))
        else:
            # Input is the flattened bottleneck: ngf*4 channels x feat_size^2.
            # Computed from the real down-sampled size so it stays correct when
            # img_size is not a multiple of 4 (identical value when it is).
            # Assumes EncodeBlock1-4 preserve spatial size (ResnetBlock is
            # defined elsewhere) — TODO confirm.
            feat_size = self._encoded_size(img_size)
            self.FC = nn.Sequential(nn.Linear(feat_size * feat_size * ngf * 4, ngf * 4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf * 4, ngf * 4),
                                    nn.ReLU(True))

        # Decoder Bottleneck
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf * 4)

        # Up-Sampling: 2x nearest-neighbour upsample + 3x3 conv each.
        self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf * 2),
                                      nn.ReLU(True))

        self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))

        self.HourGlass3 = HourGlass(ngf, ngf)
        # Third argument False: presumably disables the residual path so this
        # stage emits the 3-channel map that ConvBlock2 (3 -> 3) expects;
        # HourGlass is defined elsewhere — confirm.
        self.HourGlass4 = HourGlass(ngf, ngf, False)

        self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.Tanh())

    @staticmethod
    def _encoded_size(img_size):
        """Return the side length after DownBlock1 and DownBlock2.

        Each block (reflect-pad 1, 3x3 conv, stride 2) maps H to
        floor((H - 1) / 2) + 1 == ceil(H / 2); two of them give ceil(H / 4).
        """
        return (img_size + 3) // 4

    @staticmethod
    def _cam_branch(x, pool, fc):
        """Score one CAM branch and re-weight ``x`` by the scorer's weights.

        Args:
            x: Feature map of shape (N, C, H, W).
            pool: ``F.adaptive_avg_pool2d`` or ``F.adaptive_max_pool2d``.
            fc: ``nn.Linear(C, 1)`` scoring the pooled features.

        Returns:
            (logit, weighted): ``logit`` is (N, 1); ``weighted`` is ``x`` with
            each channel scaled by the matching entry of ``fc.weight``.
        """
        pooled = pool(x, 1).view(x.shape[0], -1)
        logit = fc(pooled)
        # fc.weight is (1, C); reshape to (1, C, 1, 1) to broadcast over H, W.
        weighted = x * fc.weight.unsqueeze(2).unsqueeze(3)
        return logit, weighted

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: 3-channel input batch, presumably (N, 3, img_size, img_size).

        Returns:
            out: 3-channel image batch squashed into [-1, 1] by Tanh.
            cam_logit: (N, 2) CAM logits, ordered [GAP branch, GMP branch].
            heatmap: (N, 1, h, w) channel sum of the fused CAM feature map.
        """
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)

        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        # Encoder: keep a globally pooled (N, C) summary after every block.
        content_features = []
        for block in (self.EncodeBlock1, self.EncodeBlock2, self.EncodeBlock3, self.EncodeBlock4):
            x = block(x)
            content_features.append(F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1))

        gap_logit, gap = self._cam_branch(x, F.adaptive_avg_pool2d, self.gap_fc)
        gmp_logit, gmp = self._cam_branch(x, F.adaptive_max_pool2d, self.gmp_fc)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = self.relu(self.conv1x1(torch.cat([gap, gmp], 1)))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.view(x_.shape[0], -1))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        # Decoder consumes the encoder summaries deepest-first (4, 3, 2, 1).
        decode_blocks = (self.DecodeBlock1, self.DecodeBlock2, self.DecodeBlock3, self.DecodeBlock4)
        for block, content in zip(decode_blocks, reversed(content_features)):
            x = block(x, content, style_features)

        x = self.UpBlock1(x)
        x = self.UpBlock2(x)

        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)

        return out, cam_logit, heatmap

第四步:运行

第五步:整个工程的内容

代码下载路径(新窗口打开链接):基于深度学习CycleGAN转卡通头像系统

有问题可以私信或者留言,有问必答

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐