2021/08/17 vgg_jax
したこと
- Jaxでvggモデル実装
from functools import partial from typing import Callable, Optional, Sequence, Tuple, List, Union import jax.numpy as jnp from flax import linen as nn from .common import ConvBlock, ModuleDef, Sequential ModuleDef = Callable[..., Callable] cfgs: Dict[str, List[Union[str, int]]] = { 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] } class VGGHead(nn.Module): num_classes: int = 1000 init_weights: bool = True activation: Callable = nn.relu @nn.compact def __call__(self, x): x = nn.Dense(features=4096)(x) x = nn.relu(x) x = nn.Dropout(x) x = nn.Dense(features=4096)(x) x = nn.relu(x) x = nn.Dropout(x) x = nn.Dense(self.num_classes)(x) return x def vgg(cfg: List[Union[int, str]], n_classes: int, conv_block_cls: ModuleDef = ConvBlock, head_layer: ModuleDef = VGGHead, norm_layer: bool = False, pool_fn: Callable = partial( nn.max_pool, window_shape=(2, 2), strides=(2, 2), padding=((1, 1), (1, 1)))) -> Sequntial: layers = [] for v in cfg: if v == "M": layers.append(pool_fn) else: if norm_layer: layers.append(conv_block_cls(v, padding=[(1, 1), (1, 1)])) else: layers.append( conv_block_cls( v, padding=[ (1, 1), (1, 1)], norm_cls=False)) layers.append(partial(jnp.mean, axis=(1, 2))) layers.append(head_layer(num_classes=n_classes)) return Sequential(layers) Vgg11 = partial(vgg, cfg=cfgs["vgg11"], head_layer=VGGHead) Vgg13 = partial(vgg, cfg=cfgs["vgg13"], head_layer=VGGHead) vgg16 = partial(vgg, cfg=cfgs["vgg16"], head_layer=VGGHead) vgg19 = partial(vgg, cfg=cfgs["vgg19"], head_layer=VGGHead)
- vggには2タイプあってconvblockはconv-bn-reluとconv-relu型があるんですね
解像度を落とすにはmaxpoolで落とす。
次mobilenetかDensenet
MBConvの美しいモデル図 github.com
消費
0円
2021_08_16_jax_mnist
プログラミング
- 以下の写経
class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(32, kernel=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(64, kernel=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x``` Jaxのモデルはこんな感じ、美しい Jax one_hot
print(jax.nn.one_hot(4,10)) ```
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
おカネの消費
0円
日記の方も始めます。
ブログでプログラミングのことは発信していましたが簡潔に自分用メモを残したいということで新たに日記も始めようと思います。
日記は毎日更新予定です。