2021_08_16_jax_mnist
プログラミング
- 以下の写経
class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(32, kernel=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(64, kernel=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x``` Jaxのモデルはこんな感じ、美しい Jax one_hot
print(jax.nn.one_hot(4,10)) ```
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
おカネの消費
0円