Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent code in pure JAX without using any other libraries:
```python
import jax.numpy as jnp
from jax import random, grad, jit, vmap

class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        key = random.PRNGKey(0)
        self.params = self.init_params(key, input_size, hidden_size, output_size)

    def init_params(self, key, input_size, hidden_size, output_size):
        key, subkey = random.split(key)
        w1 = random.normal(subkey, (input_size, hidden_size))
        key, subkey = random.split(key)
        b1 = random.normal(subkey, (hidden_size,))
        key, subkey = random.split(key)
        w2 = random.normal(subkey, (hidden_size, output_size))
        key, subkey = random.split(key)
        b2 = random.normal(subkey, (output_size,))
        return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

    def __call__(self, params, x):
        x = jnp.dot(x, params['w1']) + params['b1']
        x = jnp.maximum(0, x)  # ReLU activation
        x = jnp.dot(x, params['w2']) + params['b2']
        return jax.nn.softmax(x, axis=-1)

model = MLP(10, 20, 30)
x = random.normal(random.PRNGKey(1), (5, 10))
y = model(model.params, x)
print(list(y.flatten().sum(axis=-1)))
```
I extracted the same code from that output and ran it:
```
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 29, in <module>
    y = model(model.params, x)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/src/app/main.py", line 25, in __call__
    return jax.nn.softmax(x, axis=-1)
           ^^^
NameError: name 'jax' is not defined. Did you mean: 'max'?
```
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
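The failure here is simply a missing import: the program calls `jax.nn.softmax` but only imports `jax.numpy` and names from `jax`, never the `jax` module itself. As a hedged sketch (not part of the original transcript), the fix is either to add the import or to express softmax with `jnp` alone; note that even then, `y.flatten().sum(axis=-1)` would reduce all 150 probabilities to one number rather than per-row sums.

```python
# Hedged fix sketch for the NameError above; assumes the rest of the program is unchanged.
import jax          # the missing import: jax.nn.softmax needs the jax module to be bound
import jax.numpy as jnp

# Alternatively, stay on jnp only and define a numerically stable softmax by hand:
def softmax(x, axis=-1):
    x = x - jnp.max(x, axis=axis, keepdims=True)   # subtract the row max for stability
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=axis, keepdims=True)
```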
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent code in pure JAX without using any other libraries:
```python
import jax
import jax.numpy as jnp
from jax import random

def relu(x):
    return jnp.maximum(0, x)

def softmax(x):
    exp_x = jnp.exp(x)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

def init_mlp(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    weights1 = random.normal(key1, (input_size, hidden_size))
    weights2 = random.normal(key2, (hidden_size, output_size))
    return (weights1, weights2)

def apply_mlp(params, x):
    weights1, weights2 = params
    x = jnp.dot(x, weights1)
    x = relu(x)
    x = jnp.dot(x, weights2)
    return softmax(x)

key = random.PRNGKey(0)
input_size = 10
hidden_size = 20
output_size = 30

params = init_mlp(key, input_size, hidden_size, output_size)
x = random.normal(key, (5, 10))
y = apply_mlp(params, x)
print(list(y.flatten().sum(axis=1)))
```
I extracted the same code from that output and ran it:
```
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 34, in <module>
    print(list(y.flatten().sum(axis=1)))
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 226, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 216, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 96, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
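This ValueError comes from flattening before reducing: `y.flatten()` turns the `(5, 30)` softmax output into a 1-D array of 150 elements, so there is no axis 1 left to sum over (the original PyTorch snippet's `y.flatten().numpy().sum(1)` appears to have the same problem). A hedged sketch of the corrected final lines, shown against a stand-in array of the same shape rather than the real `y` from `apply_mlp` above:

```python
import jax.numpy as jnp

# Stand-in with the same (5, 30) shape as the real softmax output y.
y = jnp.full((5, 30), 1.0 / 30)

row_sums = y.sum(axis=1)             # sum over the class axis directly; shape (5,)
print([float(v) for v in row_sums])  # plain Python floats, each very close to 1.0
```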
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent code in pure JAX without using any other libraries:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        key, subkey = random.split(key)
        self.params = {
            'linear1': {
                'weights': random.normal(subkey, (input_size, hidden_size)),
                'bias': jnp.zeros(hidden_size)
            },
            'linear2': {
                'weights': random.normal(subkey, (hidden_size, output_size)),
                'bias': jnp.zeros(output_size)
            }
        }

    def __call__(self, params, x):
        x = jnp.dot(x, params['linear1']['weights']) + params['linear1']['bias']
        x = jax.nn.relu(x)
        x = jnp.dot(x, params['linear2']['weights']) + params['linear2']['bias']
        return jax.nn.softmax(x, axis=1)

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(model.params, x)
print(list(y.flatten().sum(axis=1)))
```
I extracted the same code from that output and ran it:
```
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 31, in <module>
    print(list(y.flatten().sum(axis=1)))
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 226, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 216, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 96, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
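Besides the same flatten-then-`sum(axis=1)` error as the previous attempt, this version also reuses randomness: `random.split(key)` is called once, the resulting `subkey` seeds both weight matrices, and the unsplit `key` is reused again for the input `x`. A hedged sketch of drawing each tensor from its own subkey, with names and shapes following the program above (illustrative only, not the original code):

```python
from jax import random

key = random.PRNGKey(0)
key, k1, k2, kx = random.split(key, 4)   # one fresh key per random draw
w1 = random.normal(k1, (10, 20))         # linear1 weights
w2 = random.normal(k2, (20, 30))         # linear2 weights
x = random.normal(kx, (5, 10))           # input batch
```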
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent code in pure JAX:
```python
import jax
import jax.numpy as jnp
from jax.nn import relu, softmax

def mlp_init(input_size, hidden_size, output_size):
    def init(rng_key):
        key1, key2 = jax.random.split(rng_key)
        weights1 = jax.random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2 / input_size)
        weights2 = jax.random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2 / hidden_size)
        biases1 = jnp.zeros((hidden_size,))
        biases2 = jnp.zeros((output_size,))
        return (weights1, biases1), (weights2, biases2)
    return init

def mlp_apply(params, x):
    (weights1, biases1), (weights2, biases2) = params
    x = jnp.dot(x, weights1) + biases1
    x = relu(x)
    x = jnp.dot(x, weights2) + biases2
    return softmax(x, axis=-1)

input_size = 10
hidden_size = 20
output_size = 30

rng_key = jax.random.PRNGKey(0)
params = mlp_init(input_size, hidden_size, output_size)(rng_key)

x = jax.random.normal(rng_key, (5, 10))
y = mlp_apply(params, x)
print(list(y.sum(axis=1)))
```
I extracted the same code from that output and ran it:
```
[Array(0.99999994, dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32)]
```
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
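This attempt runs and every row of the softmax output sums to (approximately) 1, yet the checks for the literal strings `1.0` and `1.00000` still fail, because `list()` over a JAX array yields `Array(1., dtype=float32)` elements rather than plain numbers. A small sketch, using a stand-in array of the same `(5, 30)` shape as the real `y` from `mlp_apply` above, of printing plain floats instead:

```python
import jax.numpy as jnp

# Stand-in with the same shape as the real softmax output; swap in mlp_apply(params, x).
y = jnp.full((5, 30), 1.0 / 30)

# .tolist() converts the device array to nested Python floats, so the printed text
# contains values like 1.0 rather than Array(1., dtype=float32).
print(y.sum(axis=1).tolist())
```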
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent code in pure JAX without using any other libraries:
```python
import jax.numpy as jnp
from jax import random, grad, vmap, jit

class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        key = random.PRNGKey(0)
        self.params = self.init_params(key, input_size, hidden_size, output_size)

    def init_params(self, key, input_size, hidden_size, output_size):
        key, subkey1, subkey2 = random.split(key, 3)
        params = {
            'linear1': {
                'weights': random.normal(subkey1, (input_size, hidden_size)),
                'bias': jnp.zeros(hidden_size)
            },
            'linear2': {
                'weights': random.normal(subkey2, (hidden_size, output_size)),
                'bias': jnp.zeros(output_size)
            }
        }
        return params

    def __call__(self, params, x):
        x = jnp.dot(x, params['linear1']['weights']) + params['linear1']['bias']
        x = jnp.maximum(x, 0)
        x = jnp.dot(x, params['linear2']['weights']) + params['linear2']['bias']
        return jax.nn.softmax(x, axis=-1)

model = MLP(10, 20, 30)
x = random.normal(random.PRNGKey(1), (5, 10))
y = model(model.params, x)
print(y.sum(axis=1))
```
I extracted the same code from that output and ran it:
```
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 32, in <module>
    y = model(model.params, x)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/src/app/main.py", line 28, in __call__
    return jax.nn.softmax(x, axis=-1)
           ^^^
NameError: name 'jax' is not defined. Did you mean: 'max'?
```
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
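The last attempt fails for the same reason as the first: `jax.nn.softmax` is referenced, but the program imports `jax.numpy` and names from `jax` without ever importing `jax` itself. A minimal sketch of the fix, assuming the rest of the file stays as shown:

```python
import jax                               # the missing import that makes jax.nn.softmax resolve
import jax.numpy as jnp
from jax import random, grad, vmap, jit
```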