Feature Pyramid Networks for Object Detection
PDF: https://arxiv.org/pdf/1612.03144v2.pdf
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks
(原文此处有两张插图,提取文本时图片已丢失。)
PyTorch 实现(以 torchvision 的 ResNet-50 作为骨干网络的 FPN):

import torch
import torch.nn as nn
import torchvision

class FPN(nn.Module):
    """Feature Pyramid Network built on a torchvision ResNet-50 backbone.

    The backbone is split into four stages whose outputs (c2..c5) feed a
    top-down pathway: each stage output is projected to 256 channels by a
    1x1 "lateral" convolution, the coarser pyramid level is upsampled by a
    stride-2 transposed convolution and added to it, and the merged maps
    p2..p4 are then passed through a 3x3 "smooth" convolution.

    ``forward`` returns ``(p2, p3, p4, p5)``, ordered from the finest level
    to the coarsest one; every level has 256 channels.

    NOTE(review): the element-wise additions require each upsampled map to
    have exactly the spatial size of the matching lateral map. The
    transposed convolutions here (kernel 4, stride 2, padding 1) double the
    height and width, so this presumably only works when every backbone
    stage halves the resolution exactly (e.g. the 224x224 input used in the
    demo); confirm before feeding arbitrary input sizes.
    """

    def __init__(self):
        super(FPN, self).__init__()
        # Randomly initialised ResNet-50 (no weights are requested here).
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())

        # children()[:5] is the stem plus the first residual stage;
        # children()[5..7] are the remaining residual stages. The trailing
        # children (pooling / classifier) are not used.
        self.layer1 = nn.Sequential(*layers[:5])
        self.layer2 = nn.Sequential(*layers[5])
        self.layer3 = nn.Sequential(*layers[6])
        self.layer4 = nn.Sequential(*layers[7])

        # 1x1 lateral projections: backbone channel width -> 256.
        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)
        self.lateral2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1)

        # Learned 2x upsampling; upsampleN produces the map that is merged
        # into pyramid level N.
        self.upsample2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        # 3x3 convolutions applied to the merged maps; padding=1 keeps the
        # spatial size unchanged.
        self.smooth2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.smooth4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)

        self.init_params()

    def init_params(self):
        """Re-initialise every Conv2d and BatchNorm2d in the module tree.

        Conv2d weights get Kaiming-normal initialisation (fan_out, ReLU);
        BatchNorm2d gets weight 1 and bias 0. This walks ``self.modules()``,
        so the backbone layers are re-initialised as well. Conv2d biases and
        the ConvTranspose2d layers are not touched and keep PyTorch's
        default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute the feature pyramid for an image batch.

        Args:
            x: input tensor fed to the backbone (the demo below the class
                uses a ``(1, 3, 224, 224)`` batch).

        Returns:
            Tuple ``(p2, p3, p4, p5)`` of pyramid feature maps, finest
            level first.
        """
        # Bottom-up pathway: keep the output of every backbone stage.
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down pathway: upsample the coarser level and add the lateral
        # projection of the backbone map at the same resolution. The merge
        # uses the *unsmoothed* coarser map.
        p5 = self.lateral5(c5)
        p4 = self.upsample4(p5) + self.lateral4(c4)
        p3 = self.upsample3(p4) + self.lateral3(c3)
        p2 = self.upsample2(p3) + self.lateral2(c2)

        # Each merged level has its own smoothing convolution. p2 must use
        # smooth2: routing it through smooth4 would leave smooth2 unused
        # (and untrained) and tie the p2 and p4 smoothing weights together.
        p4 = self.smooth4(p4)
        p3 = self.smooth3(p3)
        p2 = self.smooth2(p2)
        return p2, p3, p4, p5

if __name__ == '__main__':
    # Smoke test: build the network, print its structure, then print the
    # shape of every pyramid level for one random 224x224 RGB image.
    model = FPN()
    print(model)

    # Named `dummy_batch` rather than `input` so the builtin is not shadowed.
    dummy_batch = torch.randn(1, 3, 224, 224)
    # Only the output shapes are inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        pyramid = model(dummy_batch)
    # Levels come back finest first: p2, p3, p4, p5.
    for level in pyramid:
        print(level.shape)
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐