基于pytorch的dcgan代码实现,进行简易图像数据生成
使用numpy创建简易图像数据,使用dcgan生成对抗网络进行图像的自动生成
·
简易实现使用dcgan进行图像数据生成
前言:个人电脑算力有限(GPU仅有2G显存),因此使用numpy自行生成单通道(灰度)的简易图像数据;dcgan生成对抗网络训练50轮(耗时不到20min)后,生成的fake image与真实图像已肉眼难辨。
生成结果如下:
-
导入
import torchvision import torchvision.transforms as transform from torchvision.datasets import ImageFolder import torchvision.utils as vutils import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.nn import BCELoss from torch.optim import Adam,SGD import matplotlib.pyplot as plt from matplotlib import animation # 用于生成gif图像 import numpy as np
-
训练参数设置
# Hyper-parameters and runtime configuration for the DCGAN demo.
images_num = 3000      # number of training images to generate
image_size = 32        # side length of each generated image
image_channel = 1      # number of channels per generated image
batch_size = 128       # number of samples per training batch
noise_size = 100       # length of the random noise vector used to produce a fake image
lr = 2e-4              # learning rate
epochs = 50            # number of training epochs
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
批量生成图像训练集
# 1. Build a simple synthetic image dataset.
def gen_images():
    """Generate one synthetic image: a dark parabolic stroke on a bright, noisy background.

    The background is ``1 - U[0, 0.1)`` noise.  A parabola is then drawn over a
    random horizontal extent using a single random dark value in ``[0, 0.1)``;
    the stroke is three pixels thick along the first array axis.

    NOTE: the horizontal extent (start in 3..7, end in 25..29) is hard-coded and
    tuned for ``image_size == 32``; smaller sizes can index out of bounds.

    Returns:
        numpy.ndarray: float array of shape ``(image_size, image_size)``.
    """
    image = np.ones(shape=(image_size, image_size)) - np.random.rand(image_size, image_size) / 10
    min_x = np.random.choice(range(3, 8))
    max_x = np.random.choice(range(25, 30))
    # Scalar draw: assigning a shape-(1,) array into a single cell relies on
    # implicit array-to-scalar conversion.
    value = np.random.rand() / 10

    xs = np.arange(min_x, max_x)
    # Truncation matches int() here because the squared term is never negative.
    ys = (np.power((xs / max_x - 0.55) * 2, 2) * max_x).astype(int) + 3
    # The original wrote the +2 row twice and never the +1 row, which reads as a
    # copy-paste slip; draw all three adjacent rows so the stroke is solid.
    for offset in range(3):
        image[ys + offset, xs] = value
    return image


# Dataset wrapper around gen_images.
class Images:
    """Map-style dataset that synthesises a fresh random image on every access."""

    def __init__(self):
        self.len_ = images_num

    def __getitem__(self, index):
        # ``index`` is ignored: each access draws a new random image.  The image
        # gets a leading channel axis and is cast to float32; the empty string
        # is a placeholder label.
        return torch.from_numpy(gen_images()).unsqueeze(dim=0).float(), ''

    def __len__(self):
        return self.len_


# Batched image loader; drop_last=True discards a trailing partial batch.
images = Images()
dataloader = DataLoader(dataset=images, batch_size=batch_size, shuffle=True, drop_last=True)

# Preview up to 64 generated training images in an 8-per-row grid.
real_batch = next(iter(dataloader))
fig, axs = plt.subplots()
fig.set_size_inches(w=12, h=8)
# make_grid returns a channel-first image; permute to channel-last for imshow.
axs.imshow(vutils.make_grid(tensor=real_batch[0][:64], nrow=8, padding=2).permute(1, 2, 0))
axs.axis('off')
-
构建网络
# Generator network.
class G_net(nn.Module):
    """Generator: a stack of transposed convolutions ending in a sigmoid.

    Input has ``noise_size`` channels; output has ``image_channel`` channels.
    For a 1x1 spatial input the spatial size grows 1 -> 4 -> 8 -> 16 -> 32.
    """

    def __init__(self):
        super().__init__()
        hidden_widths = (image_size * 16, image_size * 8, image_size * 4)

        layers = []
        in_ch = noise_size
        for position, out_ch in enumerate(hidden_widths):
            # The first layer expands 1x1 -> 4x4 (stride 1, no padding);
            # later layers double the spatial size (stride 2, padding 1).
            is_first = position == 0
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=4,
                    stride=1 if is_first else 2,
                    padding=0 if is_first else 1,
                    bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
            in_ch = out_ch

        # Output layer: project to the image channel count, squash into (0, 1).
        layers.append(
            nn.ConvTranspose2d(
                in_channels=in_ch,
                out_channels=image_channel,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )
        )
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)


# Discriminator network, laid out as the reverse of the generator's convolutions.
class D_net(nn.Module):
    """Discriminator: strided convolutions ending in a sigmoid.

    Input has ``image_channel`` channels; output has one channel.
    For a 32x32 input the spatial size shrinks 32 -> 15 -> 6 -> 2 -> 1.
    """

    def __init__(self):
        super().__init__()
        width_1, width_2, width_3 = image_size * 2, image_size * 4, image_size * 8

        def downsample(in_ch, out_ch):
            # 4x4 convolution, stride 2, no padding, no bias.
            return nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                             kernel_size=4, stride=2, padding=0, bias=False)

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.main = nn.Sequential(
            # No batch norm after the first convolution.
            downsample(image_channel, width_1),
            leaky(),
            downsample(width_1, width_2),
            nn.BatchNorm2d(width_2),
            leaky(),
            downsample(width_2, width_3),
            nn.BatchNorm2d(width_3),
            leaky(),
            # Final 2x2 convolution reduces to a single output channel.
            nn.Conv2d(in_channels=width_3, out_channels=1,
                      kernel_size=2, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)
-
网络参数初始化
# Parameter initialisation function.
def weight_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Dispatch is by class name:

    * name contains ``'Conv'``  -> weight drawn from N(mean=0, std=0.02)
    * name contains ``'Batch'`` -> weight drawn from N(mean=1, std=0.02), bias set to 0

    Any other module is left untouched.  A matched module whose ``weight`` or
    ``bias`` is missing or ``None`` (e.g. ``BatchNorm2d(affine=False)``) is
    skipped for that parameter instead of raising.

    Args:
        m: the module to initialise.
    """
    class_name = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in class_name:
        if weight is not None:
            nn.init.normal_(weight, mean=0.0, std=0.02)
    elif 'Batch' in class_name:
        if weight is not None:
            nn.init.normal_(weight, mean=1.0, std=0.02)
        if bias is not None:
            nn.init.constant_(bias, val=0.0)
-
设置优化器和损失函数
# Build both networks on the selected device, then initialise them.
G = G_net().to(device)
D = D_net().to(device)
# Module.apply visits every sub-module, so weight_init sees each layer.
G.apply(weight_init)
D.apply(weight_init)

# Binary cross-entropy between the discriminator output and the 0/1 targets.
loss = BCELoss()

# Generator and discriminator each get their own Adam optimizer with identical settings.
G_opt, D_opt = (
    Adam(params=net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (G, D)
)
# Alternative kept from the original, disabled: plain SGD at the same learning rate.
# G_opt = SGD(params=G.parameters(),lr=lr)
# D_opt = SGD(params=D.parameters(),lr=lr)
-
训练
# Training.
# History of the objective values (recorded every 5th epoch).
loss_list_d = []
loss_list_g = []
# Mean discriminator outputs on real and fake samples.  The author's stated
# goal is for these to approach 0.5, i.e. the discriminator can no longer
# tell real from fake.
outputs_real_d_list = []
outputs_fake_d_list = []
outputs_g_list = []
# Generator outputs recorded during training.
g_gen_images = []

for epoch in range(epochs):
    for real_x, _ in dataloader:
        real_x = real_x.to(device)
        # Size labels and noise from the actual batch so a partial batch
        # (drop_last=False) would also work.
        n = real_x.size(0)
        real_label = torch.ones(n, device=device)
        fake_label = torch.zeros(n, device=device)

        # ---- Discriminator step ----
        D_opt.zero_grad()
        # Real samples should be scored as 1.
        outputs_real_d = D(real_x)
        # reshape(-1) rather than squeeze(): squeeze() would yield a 0-d tensor
        # for a single-sample batch and no longer match the label shape.
        loss_real_D = loss(outputs_real_d.reshape(-1), real_label)
        loss_real_D.backward()
        # Generated samples should be scored as 0; detach() keeps this step
        # from back-propagating into the generator.
        fake_x = G(torch.randn(n, noise_size, 1, 1, device=device))
        outputs_fake_d = D(fake_x.detach())
        loss_fake_D = loss(outputs_fake_d.reshape(-1), fake_label)
        loss_fake_D.backward()
        # Total discriminator loss, kept for logging only.
        D_loss = loss_real_D.detach() + loss_fake_D.detach()
        D_opt.step()

        # ---- Generator step ----
        G_opt.zero_grad()
        fake_x = G(torch.randn(n, noise_size, 1, 1, device=device))
        # The generator is trained to make the discriminator output 1 on fakes.
        outputs_g = D(fake_x)
        loss_fake_G = loss(outputs_g.reshape(-1), real_label)
        loss_fake_G.backward()
        G_opt.step()

    # Every 5th epoch, record statistics from the last batch of that epoch.
    if epoch % 5 == 0:
        d_loss_value = D_loss.item()
        g_loss_value = loss_fake_G.item()
        loss_list_d.append(d_loss_value)
        loss_list_g.append(g_loss_value)
        outputs_real_d_list.append(outputs_real_d.mean().item())
        outputs_fake_d_list.append(outputs_fake_d.mean().item())
        outputs_g_list.append(outputs_g.mean().item())
        print(f'第{epoch}次Dloss: {d_loss_value}')
        print(f'第{epoch}次Gloss: {g_loss_value}')
        # Store a detached CPU copy: keeping the live tensor would retain its
        # autograd graph and device memory for the rest of training.
        g_gen_images.append(fake_x.detach().cpu())

# torch.save(g_gen_images,'./g_gen_images.pkl')
-
训练结果展示
# Compare a batch of real images with a batch of generated images.
real_batch = next(iter(dataloader))
# Sample the latent noise from a standard normal (randn), matching the
# distribution used during training; the original used uniform noise (rand).
# no_grad() avoids building an autograd graph for this inference-only pass.
with torch.no_grad():
    fake_x = G(torch.randn(size=(batch_size, noise_size, 1, 1)).to(device)).cpu()

fig, axs = plt.subplots(1, 2)
fig.set_size_inches(w=12, h=8)
# Show up to 64 images per panel in an 8-per-row grid; make_grid returns a
# channel-first image, so permute to channel-last for imshow.
axs[0].imshow(torchvision.utils.make_grid(tensor=real_batch[0][:64], nrow=8, padding=2).permute(1, 2, 0))
axs[1].imshow(torchvision.utils.make_grid(tensor=fake_x[:64], nrow=8, padding=2).permute(1, 2, 0))
axs[0].axis('off')
axs[1].axis('off')
axs[0].set_title('real_images')
axs[1].set_title('fake_images')
更多推荐
所有评论(0)