class Discriminator(nn.Module):
    """Conditional GAN discriminator with an auxiliary classification head.

    The per-sample label vector is tiled over the image plane and
    concatenated to the 3-channel image along the channel axis. Four
    stride-2, 4x4 convolutions (3+num_classes -> 64 -> 128 -> 256 -> 512
    channels) extract features that feed two heads:

    * ``classify``: linear layer + softmax -> class probabilities;
    * ``disciminate``: convolution whose kernel covers the whole final
      feature map, + sigmoid -> one real/fake probability per sample.
    """

    def __init__(self, num_classes, image_size=64):
        """Build the network.

        :param num_classes: number of classes; the label passed to
            ``forward`` must have this many entries along dim 1.
        :param image_size: spatial size of the input images, either an int
            (square images) or a sequence whose last two entries are
            ``(height, width)``, e.g. ``(3, h, w)``. The default of 64 gives
            the 3x3x512 final feature map the layer sizes were designed for.
        :raises ValueError: if the images are too small for the conv stack.
        """
        super().__init__()
        self.num_classes = num_classes

        # Per-layer conv settings; the last layer uses no padding.
        channels_in = [3 + self.num_classes, 64, 128, 256]
        channels_out = [64, 128, 256, 512]
        padding = [1, 1, 1, 0]

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size[-2:]

        net = []
        layer_specs = zip(channels_in, channels_out, padding)
        for i, (c_in, c_out, pad) in enumerate(layer_specs):
            net.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                 kernel_size=4, stride=2, padding=pad, bias=False))
            if i == 0:
                # The first layer has no BatchNorm/Dropout.
                net.append(nn.LeakyReLU(0.2))
            else:
                net.extend([
                    nn.BatchNorm2d(num_features=c_out),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.5),
                ])
            # Conv2d output size: floor((s + 2*pad - kernel) / stride) + 1.
            height = (height + 2 * pad - 4) // 2 + 1
            width = (width + 2 * pad - 4) // 2 + 1

        # Once a layer's output would be < 1 it stays < 1, so checking the
        # final size is enough to reject inputs the conv stack cannot handle.
        if height < 1 or width < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small for the discriminator "
                f"(final feature map would be {height}x{width})"
            )

        final_channels = channels_out[-1]
        self.classify = nn.Linear(in_features=height * width * final_channels,
                                  out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)
        # Kernel spans the whole final feature map -> a 1x1 real/fake score.
        # The attribute name keeps its original spelling so saved
        # state_dicts still load.
        self.disciminate = nn.Conv2d(in_channels=final_channels, out_channels=1,
                                     kernel_size=(height, width), stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

        self.discriminator = nn.Sequential(*net)

    def forward(self, x, label):
        """Score a batch of images conditioned on their labels.

        :param x: image batch of shape (B, 3, H, W), with H and W matching the
            ``image_size`` given at construction.
        :param label: label vector of shape (B, num_classes) (presumably
            one-hot and float -- confirm with the caller). It is tiled over
            H x W and concatenated to ``x`` along the channel axis.
        :return: ``(real_or_fake, class_probs)``. ``real_or_fake`` is a (B, 1)
            tensor of sigmoid probabilities. ``class_probs`` is a
            (B, num_classes) tensor of softmax probabilities. These are
            probabilities, not log-probabilities: take ``log`` before passing
            ``class_probs`` to ``NLLLoss``.
        """
        batch, _, height, width = x.shape
        # (B, C) -> (B, C, H, W) view; torch.cat materializes the copy.
        label_map = label[:, :, None, None].expand(-1, -1, height, width)
        data = torch.cat((x, label_map), dim=1)

        features = self.discriminator(data)
        class_probs = self.softmax(self.classify(features.flatten(1)))
        real_or_fake = self.sigmoid(self.disciminate(features))
        return real_or_fake.view(batch, -1), class_probs

d_out_real_dis, d_out_real_cls = discriminator(image, onehot_label)
real_loss_dis = bce_loss(d_out_real_dis, real_label)
# The discriminator's class head returns softmax *probabilities*, while NLL
# loss expects log-probabilities (assuming nll_loss is nn.NLLLoss /
# F.nll_loss -- confirm). Take the log first; the clamp avoids log(0).
real_loss_cls = nll_loss(torch.log(d_out_real_cls.clamp(min=1e-12)), label)

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐