2021/08/17 vgg_jax

したこと

  • Jaxでvggモデル実装
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from flax import linen as nn

from .common import ConvBlock, ModuleDef, Sequential

# Type of a layer factory: any callable that returns a callable layer.
# NOTE(review): this rebinds the `ModuleDef` name that is also imported from
# `.common` — confirm the two definitions are meant to agree.
ModuleDef = Callable[..., Callable]

# Layer layouts for each VGG variant, consumed in order by `vgg()`.
# An int entry is handed to the conv-block factory as its first argument
# (presumably the number of output channels — confirm against ConvBlock);
# the string "M" marks a pooling step.  `Dict` comes from `typing`.
cfgs: Dict[str, List[Union[str, int]]] = {
    "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg13": [
        64, 64, "M", 128, 128, "M", 256, 256, "M",
        512, 512, "M", 512, 512, "M",
    ],
    "vgg16": [
        64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
        512, 512, 512, "M", 512, 512, 512, "M",
    ],
    "vgg19": [
        64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
        512, 512, 512, 512, "M", 512, 512, 512, 512, "M",
    ],
}


class VGGHead(nn.Module):
    """Fully connected classifier head: two 4096-unit hidden layers, then logits.

    Each hidden layer is Dense -> activation -> Dropout.  The final Dense
    layer has ``num_classes`` outputs and no activation.
    """

    # Number of outputs of the final Dense layer.
    num_classes: int = 1000
    # Not read anywhere in this module; kept so existing constructor calls
    # that pass it continue to work.
    init_weights: bool = True
    # Non-linearity applied after each hidden Dense layer.
    activation: Callable = nn.relu
    # Drop probability handed to every Dropout layer.
    dropout_rate: float = 0.5
    # Default dropout mode.  True disables dropout, so the head can be called
    # with only ``x`` (as a sequential container would) and without a
    # "dropout" RNG.  Set to False for training.
    deterministic: bool = True

    @nn.compact
    def __call__(self, x, deterministic: Optional[bool] = None):
        """Apply the head to ``x`` and return the class logits.

        Args:
            x: Input features; Dense acts on the last axis.
            deterministic: Per-call override of the ``deterministic``
                attribute.  ``None`` (default) uses the attribute value.

        Returns:
            Array whose last axis has size ``num_classes``.
        """
        if deterministic is None:
            deterministic = self.deterministic

        for _ in range(2):
            x = nn.Dense(features=4096)(x)
            x = self.activation(x)
            # Dropout is a Module: it must be constructed with its rate and
            # then *applied* to x.  `nn.Dropout(x)` would instead build a
            # module whose rate is the activation tensor.
            x = nn.Dropout(rate=self.dropout_rate,
                           deterministic=deterministic)(x)
        return nn.Dense(features=self.num_classes)(x)


def vgg(cfg: List[Union[int, str]],
        n_classes: int,
        conv_block_cls: ModuleDef = ConvBlock,
        head_layer: ModuleDef = VGGHead,
        norm_layer: bool = False,
        pool_fn: Callable = partial(
        nn.max_pool, window_shape=(2, 2), strides=(2, 2), padding=((1, 1), (1, 1)))) -> Sequential:
    """Assemble a VGG-style network from a layer layout.

    Args:
        cfg: Layer layout.  The string ``"M"`` inserts ``pool_fn``; any other
            entry is passed as the first positional argument to
            ``conv_block_cls`` (presumably the output channel count —
            confirm against ConvBlock).
        n_classes: Forwarded to ``head_layer`` as ``num_classes``.
        conv_block_cls: Factory for the convolution blocks.
        head_layer: Factory for the classifier head.
        norm_layer: When False, blocks are built with ``norm_cls=False``;
            when True, the block's own default for ``norm_cls`` is used.
        pool_fn: Callable applied at each ``"M"`` entry.
            NOTE(review): the default pads every spatial side by 1 before
            the 2x2 / stride-2 max pool, so the output is not an exact
            halving of the input.  Confirm this padding is intended.

    Returns:
        A ``Sequential`` holding the conv blocks and pooling steps in ``cfg``
        order, followed by a mean over axes (1, 2) and the head.
    """
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(pool_fn)
        elif norm_layer:
            layers.append(conv_block_cls(v, padding=[(1, 1), (1, 1)]))
        else:
            layers.append(
                conv_block_cls(v, padding=[(1, 1), (1, 1)], norm_cls=False))

    # Collapse axes 1 and 2 (presumably the spatial axes of an NHWC batch —
    # confirm) so the head receives one feature vector per example.
    layers.append(partial(jnp.mean, axis=(1, 2)))
    layers.append(head_layer(num_classes=n_classes))
    return Sequential(layers)


# Ready-made builders for the four standard depths.  `cfg` is bound by
# keyword, so the remaining arguments must also be given by keyword, e.g.
# `Vgg11(n_classes=10)`; a positional argument would collide with `cfg`.
Vgg11 = partial(vgg, cfg=cfgs["vgg11"], head_layer=VGGHead)
Vgg13 = partial(vgg, cfg=cfgs["vgg13"], head_layer=VGGHead)
vgg16 = partial(vgg, cfg=cfgs["vgg16"], head_layer=VGGHead)
vgg19 = partial(vgg, cfg=cfgs["vgg19"], head_layer=VGGHead)

# The names above mix capitalisation styles.  Export both spellings for
# every depth so callers can use either one consistently.
vgg11, vgg13 = Vgg11, Vgg13
Vgg16, Vgg19 = vgg16, vgg19
  • vggには2タイプあってconvblockはconv-bn-reluとconv-relu型があるんですね
  • 解像度を落とすにはmaxpoolで落とす。

  • 次mobilenetかDensenet

  • MBConvの美しいモデル図 github.com

消費

0円

2021_08_16_jax_mnist

プログラミング

  • 以下の写経

github.com

class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then two Dense layers.

    The final Dense layer has 10 outputs and no activation.
    """

    @nn.compact
    def __call__(self, x):
        """Return 10 outputs per example for the batch ``x``.

        ``x`` is assumed to be a batch of images with the batch on axis 0
        (the reshape below keeps axis 0 and flattens the rest) — TODO confirm
        the expected layout with the caller.
        """
        # Flax's Conv takes the window size as `kernel_size`; there is no
        # `kernel` keyword.
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading axis before the Dense layers.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

Jaxのモデルはこんな感じ、美しい

Jax one_hot

# One-hot encode index 4 over 10 classes and print the resulting vector.
print(jax.nn.one_hot(4, 10))

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

おカネの消費

0円