【笔记】ACGAN:采用辅助分类器使GAN获取图像分类功能
class Discriminator(nn.Module):def __init__(self, num_classes):"""initialize:param image_size: tuple (3, h, w)"""super().__init__()self.num_classes = num_classesnet = []...
·
class Discriminator(nn.Module):
    """Conditional ACGAN discriminator with an auxiliary classification head.

    A shared stack of four stride-2 convolutions feeds two heads:

    * ``disciminate`` + sigmoid: real/fake probability per sample.
    * ``classify`` + softmax: class *probabilities* per sample.

    The class head returns probabilities, not log-probabilities, so take the
    log before passing them to a loss that expects log-probabilities
    (e.g. ``nn.NLLLoss``).
    """

    def __init__(self, num_classes):
        """Build the convolutional trunk and the two output heads.

        :param num_classes: number of classes. The label is tiled and
            concatenated to the image as extra channels, so the first
            convolution takes ``3 + num_classes`` input channels.
        """
        super().__init__()
        self.num_classes = num_classes

        # 1: predefined per-layer configuration
        channels_in = [3 + self.num_classes, 64, 128, 256]
        channels_out = [64, 128, 256, 512]
        padding = [1, 1, 1, 0]

        net = []
        layer_specs = zip(channels_in, channels_out, padding)
        for i, (c_in, c_out, pad) in enumerate(layer_specs):
            net.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                 kernel_size=4, stride=2, padding=pad,
                                 bias=False))
            # The first block has no batch norm; every later block does.
            if i > 0:
                net.append(nn.BatchNorm2d(num_features=c_out))
            net.append(nn.LeakyReLU(0.2))
            # NOTE(review): the original indentation was lost, so whether this
            # dropout follows every block or only some is not recoverable
            # from the text -- confirm against the trained checkpoints.
            net.append(nn.Dropout(0.5))

        # The classifier expects a 3x3x512 feature map, which the trunk
        # produces for e.g. 64x64 inputs (64 -> 32 -> 16 -> 8 -> 3).
        self.classify = nn.Linear(in_features=3*3*512, out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)
        # Attribute name (including its spelling) is kept so that existing
        # state_dict keys continue to load.
        self.disciminate = nn.Conv2d(in_channels=512, out_channels=1,
                                     kernel_size=3, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()
        self.discriminator = nn.Sequential(*net)

    def forward(self, x, label):
        """Score a batch of images conditioned on their labels.

        :param x: image batch of shape ``(batch, 3, H, W)``.
        :param label: per-sample label vector of shape
            ``(batch, num_classes)`` (e.g. one-hot).
        :return: tuple ``(real_or_fake, class_probs)``; ``real_or_fake`` is
            the sigmoid output flattened to ``(batch, -1)`` and
            ``class_probs`` is the softmax output of shape
            ``(batch, num_classes)``.
        """
        batch_size = x.size(0)

        # Tile the label over the spatial grid so it can be stacked onto the
        # image as additional channels.
        label_map = label.unsqueeze(2).unsqueeze(3)
        label_map = label_map.repeat(1, 1, x.size(2), x.size(3))
        data = torch.cat(tensors=(x, label_map), dim=1)

        features = self.discriminator(data)

        class_probs = self.softmax(self.classify(features.view(batch_size, -1)))
        real_or_fake = self.sigmoid(self.disciminate(features))
        return real_or_fake.view(batch_size, -1), class_probs
# Run the discriminator on real images, conditioned on their one-hot labels.
d_out_real_dis, d_out_real_cls = discriminator(image, onehot_label)
real_loss_dis = bce_loss(d_out_real_dis, real_label)
# The discriminator's class head ends in Softmax, so d_out_real_cls holds
# probabilities, whereas an NLL loss expects log-probabilities. Take the log
# first; the clamp keeps log() finite if a probability underflows to 0.
# NOTE(review): assumes nll_loss is torch's NLL loss (nn.NLLLoss /
# F.nll_loss) -- confirm where it is defined.
real_loss_cls = nll_loss(torch.log(d_out_real_cls.clamp_min(1e-12)), label)
更多推荐
所有评论(0)