Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
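An aside on the reference code: its final line cannot actually run, because y.flatten() produces a one-dimensional tensor and .sum(1) then has no axis 1 to reduce over. The JAX answer below sidesteps this by summing over axis 1 without flattening. A minimal demonstration of the failure and of the presumably intended per-row sums:

import torch

y = torch.randn(5, 30).softmax(dim=1)
try:
    # This mirrors the prompt's final line: flatten() yields shape (150,),
    # so summing over axis 1 raises an axis-out-of-bounds error.
    y.flatten().numpy().sum(1)
except Exception as e:
    print(type(e).__name__, e)

# The presumably intended computation: per-row sums of the softmax output.
print(y.numpy().sum(1))  # each entry is ~1.0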
Here is the converted JAX code:
import jax
import jax.numpy as jnp
from jax import random

def init_params(key, input_size, hidden_size, output_size):
    k1, k2, k3, k4 = random.split(key, 4)
    a1 = jnp.sqrt(1.0 / input_size)
    w1 = random.uniform(k1, (input_size, hidden_size), minval=-a1, maxval=a1)
    b1 = random.uniform(k2, (hidden_size,), minval=-a1, maxval=a1)
    a2 = jnp.sqrt(1.0 / hidden_size)
    w2 = random.uniform(k3, (hidden_size, output_size), minval=-a2, maxval=a2)
    b2 = random.uniform(k4, (output_size,), minval=-a2, maxval=a2)
    params = {
        'linear1': {'weights': w1, 'bias': b1},
        'linear2': {'weights': w2, 'bias': b2},
    }
    return params

def mlp(params, x):
    x = jnp.dot(x, params['linear1']['weights']) + params['linear1']['bias']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['linear2']['weights']) + params['linear2']['bias']
    x = jax.nn.softmax(x, axis=1)
    return x

key = random.PRNGKey(0)
model_key, x_key = random.split(key)
params = init_params(model_key, 10, 20, 30)
x = random.normal(x_key, (5, 10))
y = mlp(params, x)
sums = jnp.sum(y, axis=1)
print(list(sums))
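Two details of the answer worth noting: the random.uniform draws bounded by ±1/sqrt(fan_in) reproduce PyTorch's default nn.Linear initialization, and every draw consumes its own split of the PRNG key, as JAX's explicit randomness requires. Since mlp is a pure function of (params, x), it could also be jit-compiled unchanged; a small optional sketch, assuming the definitions from the block above and not part of the graded answer:

fast_mlp = jax.jit(mlp)                      # traces and compiles on first call
y_fast = fast_mlp(params, x)
print(jnp.allclose(y_fast, mlp(params, x)))  # matches the uncompiled result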
I extracted the same code from that output.
Running the extracted program produced:
[Array(0.99999994, dtype=float32), Array(1.0000001, dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32)]
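Each per-row sum lands within one float32 ulp of 1.0 (0.99999994 and 1.0000001 are the representable float32 neighbors of 1.0), which is the expected signature of genuine float32 softmax arithmetic. A tiny illustration, not from the transcript:

import numpy as np

# Normalizing a float32 vector and re-summing it accumulates rounding error,
# so the total can land one float32 ulp above or below 1.0.
row = np.random.rand(30).astype(np.float32)
row /= row.sum()
print(row.sum())  # often 0.99999994 or 1.0000001 rather than exactly 1.0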
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: True
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False