Shortcuts

Source code for mmcls.models.backbones.swin_transformer

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
from .base_backbone import BaseBackbone


class SwinBlock(BaseModule):
    """One Swin Transformer block: window attention followed by an FFN.

    Each of the two sub-layers is preceded by its own norm layer. The
    attention output is added back onto the block input, and the result is
    handed to the FFN both as (normalized) input and as ``identity``.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        num_heads (int): Number of attention heads.
        window_size (int, optional): The height and width of the window.
            Defaults to 7.
        shift (bool, optional): Shift the attention window or not. When
            True the shift size is ``window_size // 2``, otherwise 0.
            Defaults to False.
        ffn_ratio (float, optional): The expansion ratio of feedforward
            network hidden layer channels. Defaults to 4.
        drop_path (float, optional): The drop path rate after attention and
            ffn. Defaults to 0.
        attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
            Its entries override the values derived from the arguments
            above. Defaults to empty dict.
        ffn_cfgs (dict, optional): The extra config of FFN. Its entries
            override the values derived from the arguments above.
            Defaults to empty dict.
        norm_cfg (dict, optional): The config of norm layers.
            Defaults to dict(type='LN').
        auto_pad (bool, optional): Auto pad the feature map to be divisible
            by window_size, Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 auto_pad=False,
                 init_cfg=None):

        super().__init__(init_cfg)

        # Attention sub-layer: defaults first, user overrides win.
        attn_defaults = dict(
            embed_dims=embed_dims,
            input_resolution=input_resolution,
            num_heads=num_heads,
            shift_size=(window_size // 2) if shift else 0,
            window_size=window_size,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            auto_pad=auto_pad)
        # Sub-modules are registered in the order norm1, attn, norm2, ffn.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**{**attn_defaults, **attn_cfgs})

        # Feed-forward sub-layer: defaults first, user overrides win.
        ffn_defaults = dict(
            embed_dims=embed_dims,
            feedforward_channels=int(embed_dims * ffn_ratio),
            num_fcs=2,
            ffn_drop=0,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            act_cfg=dict(type='GELU'))
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**{**ffn_defaults, **ffn_cfgs})

    def forward(self, x):
        """Run norm -> attention -> residual add, then norm -> FFN."""
        shortcut = x
        attended = self.attn(self.norm1(x))
        x = attended + shortcut

        # The un-normalized tensor is passed to the FFN as its identity.
        return self.ffn(self.norm2(x), identity=x)


class SwinBlockSequence(BaseModule):
    """A stage of stacked Swin blocks with an optional patch-merging tail.

    Blocks alternate between non-shifted (even positions) and shifted (odd
    positions) window attention.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        downsample (bool, optional): Downsample the output of blocks by
            patch merging. Defaults to False.
        downsample_cfg (dict, optional): The extra config of the patch
            merging layer. Its entries override the built-in defaults.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float, optional): The drop path rate
            in each block. A scalar is repeated for every block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict, optional): The extra config of
            each block. A single dict is deep-copied for every block.
            Defaults to empty dicts.
        auto_pad (bool, optional): Auto pad the feature map to be divisible
            by window_size, Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 depth,
                 num_heads,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 auto_pad=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Broadcast scalar / single-dict settings to one entry per block.
        if isinstance(drop_paths, Sequence):
            per_block_drop = drop_paths
        else:
            per_block_drop = [drop_paths] * depth

        if isinstance(block_cfgs, Sequence):
            per_block_cfg = block_cfgs
        else:
            per_block_cfg = [deepcopy(block_cfgs) for _ in range(depth)]

        self.embed_dims = embed_dims
        self.input_resolution = input_resolution
        self.blocks = ModuleList()
        for idx in range(depth):
            block_defaults = dict(
                embed_dims=embed_dims,
                input_resolution=input_resolution,
                num_heads=num_heads,
                # odd-positioned blocks use the shifted window
                shift=(idx % 2 != 0),
                drop_path=per_block_drop[idx],
                auto_pad=auto_pad)
            self.blocks.append(
                SwinBlock(**{**block_defaults, **per_block_cfg[idx]}))

        if not downsample:
            self.downsample = None
        else:
            merge_defaults = dict(
                input_resolution=input_resolution,
                in_channels=embed_dims,
                expansion_ratio=2,
                norm_cfg=dict(type='LN'))
            self.downsample = PatchMerging(
                **{**merge_defaults, **downsample_cfg})

    def forward(self, x):
        """Apply every block in order, then the downsample layer if any."""
        out = x
        for blk in self.blocks:
            out = blk(out)

        return self.downsample(out) if self.downsample else out

    @property
    def out_resolution(self):
        """Resolution after this stage (the input one if not downsampled)."""
        if not self.downsample:
            return self.input_resolution
        return self.downsample.output_resolution

    @property
    def out_channels(self):
        """Channel count after this stage (``embed_dims`` if not
        downsampled)."""
        if not self.downsample:
            return self.embed_dims
        return self.downsample.out_channels


@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
    """Swin Transformer.

    A PyTorch implementation of: `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. Either a key of
            ``arch_zoo`` (matched case-insensitively) or a dict with
            exactly the keys ``embed_dims``, ``depths`` and ``num_heads``.
            Defaults to 'T'.
        img_size (int | tuple): The size of input image.
            Defaults to 224.
        in_channels (int): The num of input channels.
            Defaults to 3.
        drop_rate (float): Dropout rate after embedding.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate.
            Defaults to 0.1.
        out_indices (Sequence[int]): Indices of the stages whose output is
            normalized and returned by ``forward``. Each index must refer
            to an existing stage. Defaults to ``(3, )``.
        use_abs_pos_embed (bool): If True, add absolute position embedding
            to the patch embedding. Defaults to False.
        auto_pad (bool): If True, auto pad feature map to fit window_size.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for the normalization layer
            applied to each returned stage output. ``None`` disables the
            normalization. Defaults to dict(type='LN')
        stage_cfgs (Sequence | dict, optional): Extra config dict for each
            stage. Defaults to empty dict.
        patch_cfg (dict, optional): Extra config dict for patch embedding.
            Defaults to empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}),
        >>>     auto_pad=True)
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = self.forward(inputs)
        >>> # ``forward`` returns a tuple, one tensor per out index
        >>> print(outputs[0].shape)
        torch.Size([1, 2592, 2, 2])
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths':     [2, 2,  6,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths':     [2, 2, 18,  2],
                         'num_heads':  [6, 12, 24, 48]}),
    }  # yapf: disable

    # Checkpoint format version; see ``_load_from_state_dict``.
    _version = 2

    def __init__(self,
                 arch='T',
                 img_size=224,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 auto_pad=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            # The required keys must match the ones read below
            # ('num_heads', not 'num_head'); otherwise no custom arch dict
            # could ever be accepted and used.
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.auto_pad = auto_pad

        # Defaults for the patch embedding; ``patch_cfg`` entries override.
        _patch_cfg = {
            'img_size': img_size,
            'in_channels': in_channels,
            'embed_dims': self.embed_dims,
            'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
            'norm_cfg': dict(type='LN'),
            **patch_cfg
        }
        self.patch_embed = PatchEmbed(**_patch_cfg)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.use_abs_pos_embed:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: the drop path rate grows linearly
        # from 0 to ``drop_path_rate`` over all blocks of all stages.
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        embed_dims = self.embed_dims
        input_resolution = patches_resolution
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage except the last one ends with a downsample layer.
            downsample = i < self.num_layers - 1
            _stage_cfg = {
                'embed_dims': embed_dims,
                'depth': depth,
                'num_heads': num_heads,
                'downsample': downsample,
                'input_resolution': input_resolution,
                'drop_paths': dpr[:depth],
                'auto_pad': auto_pad,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume this stage's drop path rates and chain its output
            # shape into the next stage.
            dpr = dpr[depth:]
            embed_dims = stage.out_channels
            input_resolution = stage.out_resolution

        for i in out_indices:
            if norm_cfg is not None:
                # Size the norm after the stage it is applied to. Using the
                # final stage's channel count for every index breaks
                # ``forward`` as soon as an earlier stage is requested.
                norm_layer = build_norm_layer(
                    norm_cfg, self.stages[i].out_channels)[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        """Initialize weights; skip custom init when loading pretrained."""
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (torch.Tensor): Input passed to the patch embedding.

        Returns:
            tuple[torch.Tensor]: One tensor per stage listed in
            ``out_indices``, in stage order. Each is normalized, reshaped
            to ``(-1, *out_resolution, out_channels)`` and permuted so the
            channel dimension comes second.
        """
        x = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *stage.out_resolution,
                               stage.out_channels).permute(0, 3, 1,
                                                           2).contiguous()
                outs.append(out)

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
                              **kwargs):
        """Load checkpoints, renaming keys of pre-version-2 checkpoints.

        Checkpoints without a version (or with version < 2) store the final
        norm layer under ``norm.``; it is now named ``norm{i}`` with ``i``
        the index of the last stage. The renaming only applies when this
        instance is exactly ``SwinTransformer``, not a subclass.
        """
        version = local_metadata.get('version', None)
        if (version is None
                or version < 2) and self.__class__ is SwinTransformer:
            final_stage_num = len(self.stages) - 1
            # Snapshot the keys: the dict is modified while renaming.
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if k.startswith(('norm.', 'backbone.norm.')):
                    convert_key = k.replace('norm.',
                                            f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict.pop(k)

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)