123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- #!/usr/bin/env python
- # -*- encoding: utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- import torch
- import torch.nn as nn
- from .darknet import Darknet
- from .network_blocks import BaseConv
class YOLOFPN(nn.Module):
    """
    YOLOFPN module. Darknet 53 is the default backbone of this model.

    The backbone output is indexed with the keys in ``in_features``; the three
    selected feature maps are fused top-down (1x1 conv -> 2x nearest upsample
    -> channel concat -> five-conv embedding) to produce three outputs.

    Args:
        depth (int): depth passed straight to ``Darknet``. Default: 53.
        in_features (Sequence[str]): exactly three keys of the backbone output,
            unpacked in ``forward`` as ``x2, x1, x0``; ``x0`` is the one fed to
            the first lateral conv. The default is an immutable tuple so that
            instances never share (and cannot accidentally mutate) one list.
    """

    def __init__(
        self,
        depth=53,
        in_features=("dark3", "dark4", "dark5"),
    ):
        super().__init__()

        self.backbone = Darknet(depth)
        self.in_features = in_features

        # out 1
        # NOTE(review): the channel counts below are hard-coded; they presumably
        # match Darknet(53) outputs -- confirm before using another depth.
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1 = self._make_embedding([256, 512], 512 + 256)

        # out 2
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2 = self._make_embedding([128, 256], 256 + 128)

        # upsample (one module, reused by both branches in forward)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def _make_cbl(self, _in, _out, ks):
        """Build a stride-1 ``BaseConv`` with "lrelu" activation.

        Args:
            _in: first positional argument of ``BaseConv`` (presumably the
                input channel count -- confirm against ``network_blocks``).
            _out: second positional argument (presumably output channels).
            ks: third positional argument (presumably the kernel size).
        """
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

    def _make_embedding(self, filters_list, in_filters):
        """Build a five-layer ``nn.Sequential`` of alternating 1 / 3 convs.

        Args:
            filters_list: two-element sequence; layers alternate between
                ``filters_list[0]`` (for ks=1) and ``filters_list[1]``
                (for ks=3) as their second ``_make_cbl`` argument.
            in_filters: first ``_make_cbl`` argument of the first layer.

        Returns:
            nn.Sequential: the five blocks, ending on ``filters_list[0]``.
        """
        narrow, wide = filters_list[0], filters_list[1]
        return nn.Sequential(
            self._make_cbl(in_filters, narrow, 1),
            self._make_cbl(narrow, wide, 3),
            self._make_cbl(wide, narrow, 1),
            self._make_cbl(narrow, wide, 3),
            self._make_cbl(wide, narrow, 1),
        )

    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
        """Load a state dict from ``filename`` into the backbone only.

        Args:
            filename (str): path of the checkpoint file.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only pass checkpoints that come from a trusted source.
        with open(filename, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        print("loading pretrained weights...")
        self.backbone.load_state_dict(state_dict)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input image.

        Returns:
            Tuple[Tensor]: FPN output features, as
                ``(out_dark3, out_dark4, x0)`` where ``x0`` is the backbone
                feature for ``in_features[2]`` returned unchanged.
        """
        # backbone
        out_features = self.backbone(inputs)
        x2, x1, x0 = [out_features[f] for f in self.in_features]

        # yolo branch 1: lateral conv on x0, upsample 2x, concat with x1 on dim 1
        x1_in = self.out1_cbl(x0)
        x1_in = self.upsample(x1_in)
        x1_in = torch.cat([x1_in, x1], 1)
        out_dark4 = self.out1(x1_in)

        # yolo branch 2: same pattern, starting from the branch-1 output
        x2_in = self.out2_cbl(out_dark4)
        x2_in = self.upsample(x2_in)
        x2_in = torch.cat([x2_in, x2], 1)
        out_dark3 = self.out2(x2_in)

        outputs = (out_dark3, out_dark4, x0)
        return outputs
|