基于深度学习神经网络CycleGAN转卡通头像系统
基于深度学习CycleGAN转卡通头像系统
·
第一步:CycleGAN介绍
CycleGAN主要用于图像域之间的转换。假设有两个图像域X和Y,两者中的图像不需要一一配对。算法通过训练学习两个方向的映射G:X→Y和F:Y→X,并用“循环一致性”约束F(G(x))≈x、G(F(y))≈y,从而实现两个域之间的“自动相互转换”。训练时不需要成对的配对样本,只需要源域和目标域的图像。训练后,网络就能把源域图像迁移到目标域。CycleGAN适用于非配对的图像到图像转换,解决了模型需要成对数据进行训练的困难。
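第二步:CycleGAN网络结构
第三步:模型代码展示(以下为生成器ResnetGenerator的代码。其中的类激活图(CAM)注意力与AdaLIN归一化结构与U-GAT-IT的生成器相近;HourGlass、ResnetBlock、ResnetSoftAdaLINBlock、LIN等模块定义在工程的其他文件中,此处未展示)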
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class ResnetGenerator(nn.Module):
    """Generator that maps a 3-channel image to a stylized 3-channel image.

    Stages, in the order ``forward`` runs them:
      1. stem 7x7 conv, then two ``HourGlass`` blocks;
      2. two stride-2 down-sampling convs (``ngf`` -> ``ngf * 4`` channels);
      3. four encoder ``ResnetBlock``s; after each one a globally
         average-pooled "content feature" vector is kept;
      4. class-activation-map (CAM) attention from a GAP and a GMP linear
         classifier whose weights re-weight the feature channels;
      5. a small MLP (``FC``) that produces a "style feature" vector;
      6. four ``ResnetSoftAdaLINBlock`` decoder blocks, each fed one content
         feature (deepest first) plus the style feature;
      7. two up-sampling blocks, two ``HourGlass`` blocks, and a final
         7x7 conv + tanh.

    ``HourGlass``, ``ResnetBlock``, ``ResnetSoftAdaLINBlock`` and ``LIN`` are
    not defined here; they must be provided elsewhere in the project.

    Args:
        ngf: Base channel width. The bottleneck runs at ``ngf * 4`` channels.
        img_size: Input height/width. Only used when ``light`` is False, to
            size the first layer of ``FC``. That layer consumes the flattened
            bottleneck map, assumed to be ``img_size // 4`` per side after the
            two stride-2 convs. This holds when ``img_size`` is divisible by 4
            and the HourGlass blocks preserve spatial size (TODO confirm).
        light: If True, average-pool the bottleneck to a vector before
            ``FC``. This makes ``FC`` independent of ``img_size`` and much
            smaller.
    """

    def __init__(self, ngf=64, img_size=256, light=False):
        super().__init__()
        self.light = light

        # Stem: 3 -> ngf channels, reflection padding keeps the spatial size.
        self.ConvBlock1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        )
        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-sampling: each stride-2 conv halves H and W, doubles channels.
        self.DownBlock1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True),
        )
        self.DownBlock2 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(ngf * 4),
            nn.ReLU(True),
        )

        # Encoder bottleneck.
        self.EncodeBlock1 = ResnetBlock(ngf * 4)
        self.EncodeBlock2 = ResnetBlock(ngf * 4)
        self.EncodeBlock3 = ResnetBlock(ngf * 4)
        self.EncodeBlock4 = ResnetBlock(ngf * 4)

        # Class Activation Map: one logit each from global-average and
        # global-max pooled features. conv1x1 fuses the two attended maps
        # (2 * ngf * 4 channels) back to ngf * 4.
        self.gap_fc = nn.Linear(ngf * 4, 1)
        self.gmp_fc = nn.Linear(ngf * 4, 1)
        self.conv1x1 = nn.Conv2d(ngf * 8, ngf * 4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Style MLP that feeds the AdaLIN decoder blocks.
        if self.light:
            fc_in = ngf * 4
        else:
            # Flattened (ngf * 4, img_size // 4, img_size // 4) bottleneck.
            fc_in = (img_size // 4) * (img_size // 4) * ngf * 4
        self.FC = nn.Sequential(
            nn.Linear(fc_in, ngf * 4),
            nn.ReLU(True),
            nn.Linear(ngf * 4, ngf * 4),
            nn.ReLU(True),
        )

        # Decoder bottleneck.
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf * 4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf * 4)

        # Up-sampling: nearest x2, then a 3x3 conv that halves channels.
        self.UpBlock1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=0, bias=False),
            LIN(ngf * 2),
            nn.ReLU(True),
        )
        self.UpBlock2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
            LIN(ngf),
            nn.ReLU(True),
        )

        self.HourGlass3 = HourGlass(ngf, ngf)
        # Built with a third argument of False. The output conv below takes
        # 3 input channels, so this block presumably emits a 3-channel map
        # in that mode -- TODO confirm against the HourGlass definition.
        self.HourGlass4 = HourGlass(ngf, ngf, False)
        self.ConvBlock2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    @staticmethod
    def _global_avg(x):
        """Average-pool a (N, C, H, W) tensor over H and W and flatten to (N, C)."""
        return F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: Input batch. The stem conv requires 3 channels; presumably
                (N, 3, img_size, img_size) -- TODO confirm.

        Returns:
            A tuple ``(out, cam_logit, heatmap)``:
            ``out`` is the generated image after ``tanh`` (values in [-1, 1]).
            ``cam_logit`` is (N, 2): the GAP and GMP classifier logits,
            concatenated. ``heatmap`` is (N, 1, h, w): the channel-sum of the
            attended bottleneck features.
        """
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)
        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        # Encoder: keep a pooled (N, ngf * 4) summary after every block; the
        # decoder consumes them in reverse order.
        x = self.EncodeBlock1(x)
        content_features1 = self._global_avg(x)
        x = self.EncodeBlock2(x)
        content_features2 = self._global_avg(x)
        x = self.EncodeBlock3(x)
        content_features3 = self._global_avg(x)
        x = self.EncodeBlock4(x)
        content_features4 = self._global_avg(x)

        # CAM attention: the (1, C) weight of each classifier is broadcast
        # as a per-channel scale over the (N, C, h, w) feature map.
        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))
        heatmap = torch.sum(x, dim=1, keepdim=True)

        # Style vector for the AdaLIN decoder blocks.
        if self.light:
            style_features = self.FC(self._global_avg(x))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        # Decoder: deepest content features go to the first decode block.
        x = self.DecodeBlock1(x, content_features4, style_features)
        x = self.DecodeBlock2(x, content_features3, style_features)
        x = self.DecodeBlock3(x, content_features2, style_features)
        x = self.DecodeBlock4(x, content_features1, style_features)

        x = self.UpBlock1(x)
        x = self.UpBlock2(x)
        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)
        return out, cam_logit, heatmap
第四步:运行
第五步:整个工程的内容
代码的下载路径(新窗口打开链接):基于深度学习CycleGan转卡通头像系统
有问题可以私信或者留言,有问必答
更多推荐
所有评论(0)