2021_08_16_jax_mnist

プログラミング

  • 以下の写経

github.com

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(32, kernel=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(64, kernel=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x```

Jaxのモデルはこんな感じ、美しい

Jax one_hot

print(jax.nn.one_hot(4,10)) ```

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

おカネの消費

0円