2021/08/17 vgg_jax
したこと
- JAXでVGGモデルを実装
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from flax import linen as nn

from .common import ConvBlock, ModuleDef, Sequential

# Per-variant layer configuration for `vgg`: an int is the output channel
# count of a conv block, and "M" inserts the pooling function (`pool_fn`).
cfgs: Dict[str, List[Union[str, int]]] = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGGHead(nn.Module):
    """VGG classifier head.

    Applies ``Dense(4096) -> activation -> Dropout`` twice, then a final
    ``Dense(num_classes)`` producing logits.

    Attributes:
        num_classes: Number of output logits.
        init_weights: Currently unused; kept for interface compatibility.
        activation: Non-linearity applied after each hidden Dense layer.
        dropout_rate: Drop probability of each Dropout layer.
        deterministic: Default dropout mode. When True (the default),
            dropout is disabled so the head can be called as ``head(x)``
            without a PRNG key. Can be overridden per call.
    """
    num_classes: int = 1000
    init_weights: bool = True
    activation: Callable = nn.relu
    dropout_rate: float = 0.5
    deterministic: bool = True

    @nn.compact
    def __call__(self, x, deterministic: Optional[bool] = None):
        """Map pooled features ``x`` to class logits.

        Args:
            x: Feature array whose last axis is fed to the Dense layers.
            deterministic: Overrides the ``deterministic`` attribute for this
                call. When False, dropout is active and a ``'dropout'`` RNG
                must be supplied via ``apply(..., rngs={'dropout': key})``.

        Returns:
            Logits whose last axis has size ``num_classes``.
        """
        if deterministic is None:
            deterministic = self.deterministic
        # Two identical hidden stages; flax names them Dense_0/Dense_1 and
        # Dropout_0/Dropout_1 exactly as an unrolled version would.
        for _ in range(2):
            x = nn.Dense(features=4096)(x)
            x = self.activation(x)
            # Dropout's first constructor field is the *rate*; activations
            # are passed to the call, not the constructor.
            x = nn.Dropout(rate=self.dropout_rate)(
                x, deterministic=deterministic)
        return nn.Dense(features=self.num_classes)(x)


def vgg(cfg: List[Union[int, str]],
        n_classes: int,
        conv_block_cls: ModuleDef = ConvBlock,
        head_layer: ModuleDef = VGGHead,
        norm_layer: bool = False,
        pool_fn: Callable = partial(
            nn.max_pool, window_shape=(2, 2), strides=(2, 2),
            padding=((1, 1), (1, 1)))) -> Sequential:
    """Build a VGG network as a ``Sequential`` stack of layers.

    Args:
        cfg: Layer spec. Each int adds a conv block with that many output
            channels; each ``"M"`` adds ``pool_fn``.
        n_classes: Number of output classes, passed to the head as
            ``num_classes``.
        conv_block_cls: Conv block factory, called as
            ``conv_block_cls(channels, padding=[(1, 1), (1, 1)])``, plus
            ``norm_cls=False`` when ``norm_layer`` is False.
        head_layer: Classifier head class, constructed with
            ``num_classes=n_classes``.
        norm_layer: If False, conv blocks get ``norm_cls=False`` (conv-act
            variant). If True, the block's default normalization is kept
            (presumably conv-bn-act; depends on ``ConvBlock``).
        pool_fn: Callable inserted at every ``"M"`` entry.

    Returns:
        ``Sequential([...conv/pool layers, spatial mean, head])``.
    """
    # NOTE(review): canonical VGG pools 2x2/stride 2 *without* padding; the
    # (1, 1) padding in the default `pool_fn` yields H/2 + 1 rows for even H.
    # Left unchanged because it is part of this function's interface.
    block_kwargs = {} if norm_layer else {"norm_cls": False}
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(pool_fn)
        else:
            layers.append(
                conv_block_cls(v, padding=[(1, 1), (1, 1)], **block_kwargs))
    # Global average over axes 1 and 2. This assumes NHWC activations
    # (flax's default layout); TODO confirm against ConvBlock's output.
    layers.append(partial(jnp.mean, axis=(1, 2)))
    layers.append(head_layer(num_classes=n_classes))
    return Sequential(layers)


# Variant constructors. `cfg` is bound by keyword, so call them with a
# keyword argument, e.g. `vgg16(n_classes=1000)`.
Vgg11 = partial(vgg, cfg=cfgs["vgg11"], head_layer=VGGHead)
Vgg13 = partial(vgg, cfg=cfgs["vgg13"], head_layer=VGGHead)
vgg16 = partial(vgg, cfg=cfgs["vgg16"], head_layer=VGGHead)
vgg19 = partial(vgg, cfg=cfgs["vgg19"], head_layer=VGGHead)

# Lowercase aliases so all four variants share one naming scheme.
vgg11 = Vgg11
vgg13 = Vgg13
- VGGには2タイプあって、ConvBlockがconv-bn-relu型とconv-relu型に分かれるんですね
解像度を落とすにはmax poolingを使う。
次はMobileNetかDenseNet。
MBConvの美しいモデル図 github.com
消費
0円