Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Here's the JAX version of the PyTorch code you provided:
import jax.numpy as jnp
from jax import random, vmap
from jax.nn import relu, softmax

def init_mlp_params(key, input_size, hidden_size, output_size):
    keys = random.split(key, 2)
    w1_key, w2_key = keys
    w1 = random.normal(w1_key, (input_size, hidden_size)) / jnp.sqrt(input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(w2_key, (hidden_size, output_size)) / jnp.sqrt(hidden_size)
    b2 = jnp.zeros(output_size)
    return (w1, b1, w2, b2)

def mlp_forward(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = relu(x)
    x = jnp.dot(x, w2) + b2
    return softmax(x)

key = random.PRNGKey(0)
params = init_mlp_params(key, 10, 20, 30)

x = random.normal(key, (5, 10))

y = vmap(mlp_forward, in_axes=(None, 0))(params, x)

print(list(y.flatten().sum(1)))
I extracted the following code from that output:
import jax.numpy as jnp
from jax import random, vmap
from jax.nn import relu, softmax

def init_mlp_params(key, input_size, hidden_size, output_size):
    keys = random.split(key, 2)
    w1_key, w2_key = keys
    w1 = random.normal(w1_key, (input_size, hidden_size)) / jnp.sqrt(input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(w2_key, (hidden_size, output_size)) / jnp.sqrt(hidden_size)
    b2 = jnp.zeros(output_size)
    return (w1, b1, w2, b2)

def mlp_forward(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = relu(x)
    x = jnp.dot(x, w2) + b2
    return softmax(x)

key = random.PRNGKey(0)
params = init_mlp_params(key, 10, 20, 30)

x = random.normal(key, (5, 10))

y = vmap(mlp_forward, in_axes=(None, 0))(params, x)

print(list(y.flatten().sum(1)))
Running the following program:
import jax.numpy as jnp
from jax import random, vmap
from jax.nn import relu, softmax

def init_mlp_params(key, input_size, hidden_size, output_size):
    keys = random.split(key, 2)
    w1_key, w2_key = keys
    w1 = random.normal(w1_key, (input_size, hidden_size)) / jnp.sqrt(input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(w2_key, (hidden_size, output_size)) / jnp.sqrt(hidden_size)
    b2 = jnp.zeros(output_size)
    return (w1, b1, w2, b2)

def mlp_forward(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = relu(x)
    x = jnp.dot(x, w2) + b2
    return softmax(x)

key = random.PRNGKey(0)
params = init_mlp_params(key, 10, 20, 30)

x = random.normal(key, (5, 10))

y = vmap(mlp_forward, in_axes=(None, 0))(params, x)

print(list(y.flatten().sum(1)))
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 30, in <module>
    print(list(y.flatten().sum(1)))
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 226, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 216, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 96, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 154, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/reductions.py", line 147, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
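The ValueError above comes from the final print statement: y.flatten() collapses the (5, 30) softmax output into a 1-D array, and a 1-D array has no axis 1 to sum over (the original PyTorch line appears to have the same shape problem). A minimal sketch of a corrected final line, assuming the intent is to print the per-row probability sums, would sum over axis 1 before flattening:

import jax.numpy as jnp

y = jnp.full((5, 30), 1.0 / 30)       # stand-in for the (5, 30) softmax output above
# print(list(y.flatten().sum(1)))     # fails: a flattened array is 1-D and has no axis 1
print(list(y.sum(axis=1).flatten()))  # sums each row first, giving 5 values near 1.0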
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Here's the JAX equivalent of the provided PyTorch code:
import jax
import jax.numpy as jnp
from jax import random

# Define the MLP class
class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        self.linear1 = random.normal(random.PRNGKey(0), (input_size, hidden_size))
        self.linear2 = random.normal(random.PRNGKey(1), (hidden_size, output_size))

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = jnp.maximum(x, 0)  # ReLU activation
        x = jnp.dot(x, self.linear2)
        return jax.nn.softmax(x, axis=1)

# Create an instance of the MLP
model = MLP(10, 20, 30)

# Generate some random input data
key = random.PRNGKey(2)
x = random.normal(key, (5, 10))

# Forward pass through the MLP
y = model(x)

# Print the sum of the output logits for each example
print(list(jnp.sum(y, axis=1)))
I extracted the following code from that output:
import jax
import jax.numpy as jnp
from jax import random

# Define the MLP class
class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        self.linear1 = random.normal(random.PRNGKey(0), (input_size, hidden_size))
        self.linear2 = random.normal(random.PRNGKey(1), (hidden_size, output_size))

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = jnp.maximum(x, 0)  # ReLU activation
        x = jnp.dot(x, self.linear2)
        return jax.nn.softmax(x, axis=1)

# Create an instance of the MLP
model = MLP(10, 20, 30)

# Generate some random input data
key = random.PRNGKey(2)
x = random.normal(key, (5, 10))

# Forward pass through the MLP
y = model(x)

# Print the sum of the output logits for each example
print(list(jnp.sum(y, axis=1)))
Running the following program:
import jax
import jax.numpy as jnp
from jax import random

# Define the MLP class
class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        self.linear1 = random.normal(random.PRNGKey(0), (input_size, hidden_size))
        self.linear2 = random.normal(random.PRNGKey(1), (hidden_size, output_size))

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = jnp.maximum(x, 0)  # ReLU activation
        x = jnp.dot(x, self.linear2)
        return jax.nn.softmax(x, axis=1)

# Create an instance of the MLP
model = MLP(10, 20, 30)

# Generate some random input data
key = random.PRNGKey(2)
x = random.normal(key, (5, 10))

# Forward pass through the MLP
y = model(x)

# Print the sum of the output logits for each example
print(list(jnp.sum(y, axis=1)))
[Array(0.99999994, dtype=float32), Array(0.9999999, dtype=float32), Array(0.9999999, dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
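This attempt computes the row sums correctly, but the checks for 1.0 and 1.00000 still fail because list() over a JAX array yields Array(...) objects, so the printed output contains reprs like Array(1., dtype=float32) rather than plain numbers. A small sketch, not part of the logged attempt, of a print that would emit plain Python floats:

import jax.numpy as jnp

y = jnp.full((5, 30), 1.0 / 30)              # stand-in for the softmax output above
print(list(map(float, jnp.sum(y, axis=1))))  # plain Python floats very close to 1.0
print(jnp.sum(y, axis=1).tolist())           # equivalent: tolist() also returns Python floats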
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Sure, here's the equivalent code in pure JAX:
import jax.numpy as jnp
from jax import random

# Initialize RNG key
key = random.PRNGKey(0)

def init_weights(m, n, key):
    return random.normal(key, (m, n)) / jnp.sqrt(m)

def mlp(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = jnp.maximum(x, 0)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

def init_mlp(input_size, hidden_size, output_size, key):
    keys = random.split(key, 4)
    return [init_weights(input_size, hidden_size, keys[0]),
            jnp.zeros(hidden_size),
            init_weights(hidden_size, output_size, keys[2]),
            jnp.zeros(output_size)]

# Initialize model parameters
params = init_mlp(10, 20, 30, key)

# Input data
x = random.normal(key, (5, 10))

# Forward pass
y = mlp(params, x)

# Print sum of logits for each example
print(list(jnp.sum(y, axis=1)))
I extracted the following code from that output:
import jax.numpy as jnp
from jax import random

# Initialize RNG key
key = random.PRNGKey(0)

def init_weights(m, n, key):
    return random.normal(key, (m, n)) / jnp.sqrt(m)

def mlp(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = jnp.maximum(x, 0)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

def init_mlp(input_size, hidden_size, output_size, key):
    keys = random.split(key, 4)
    return [init_weights(input_size, hidden_size, keys[0]),
            jnp.zeros(hidden_size),
            init_weights(hidden_size, output_size, keys[2]),
            jnp.zeros(output_size)]

# Initialize model parameters
params = init_mlp(10, 20, 30, key)

# Input data
x = random.normal(key, (5, 10))

# Forward pass
y = mlp(params, x)

# Print sum of logits for each example
print(list(jnp.sum(y, axis=1)))
Running the following program:
import jax.numpy as jnp
from jax import random

# Initialize RNG key
key = random.PRNGKey(0)

def init_weights(m, n, key):
    return random.normal(key, (m, n)) / jnp.sqrt(m)

def mlp(params, x):
    w1, b1, w2, b2 = params
    x = jnp.dot(x, w1) + b1
    x = jnp.maximum(x, 0)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

def init_mlp(input_size, hidden_size, output_size, key):
    keys = random.split(key, 4)
    return [init_weights(input_size, hidden_size, keys[0]),
            jnp.zeros(hidden_size),
            init_weights(hidden_size, output_size, keys[2]),
            jnp.zeros(output_size)]

# Initialize model parameters
params = init_mlp(10, 20, 30, key)

# Input data
x = random.normal(key, (5, 10))

# Forward pass
y = mlp(params, x)

# Print sum of logits for each example
print(list(jnp.sum(y, axis=1)))
Traceback (most recent call last):
File "/usr/src/app/main.py", line 32, in <module>
y = mlp(params, x)
^^^^^^^^^^^^^^
File "/usr/src/app/main.py", line 16, in mlp
return jax.nn.softmax(x, axis=1)
^^^
NameError: name 'jax' is not defined. Did you mean: 'max'?
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
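The NameError here is an import problem rather than a logic problem: the attempt imports jax.numpy as jnp and random from jax, so the top-level name jax is never bound, yet mlp() calls jax.nn.softmax. A minimal fix, under that reading, is one extra import (either form works):

import jax                  # binds the name jax, making jax.nn.softmax available
# or, equivalently:
from jax.nn import softmax  # then call softmax(x, axis=1) inside mlp()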
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Here's the equivalent code in pure JAX:
import jax.numpy as jnp
from jax import random, jit

def relu(x):
    return jnp.maximum(0, x)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    y = jnp.exp(x)
    return y / y.sum(axis=axis, keepdims=True)

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        self.key1, self.key2 = random.split(key)
        self.linear1 = random.normal(self.key1, (input_size, hidden_size)) / jnp.sqrt(input_size)
        self.linear2 = random.normal(self.key2, (hidden_size, output_size)) / jnp.sqrt(hidden_size)

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = relu(x)
        x = jnp.dot(x, self.linear2)
        return softmax(x, axis=1)

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(x)
print(list(y.sum(axis=1).flatten()))
I extracted the following code from that output:
import jax.numpy as jnp
from jax import random, jit

def relu(x):
    return jnp.maximum(0, x)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    y = jnp.exp(x)
    return y / y.sum(axis=axis, keepdims=True)

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        self.key1, self.key2 = random.split(key)
        self.linear1 = random.normal(self.key1, (input_size, hidden_size)) / jnp.sqrt(input_size)
        self.linear2 = random.normal(self.key2, (hidden_size, output_size)) / jnp.sqrt(hidden_size)

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = relu(x)
        x = jnp.dot(x, self.linear2)
        return softmax(x, axis=1)

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(x)
print(list(y.sum(axis=1).flatten()))
Running the following program:
import jax.numpy as jnp
from jax import random, jit

def relu(x):
    return jnp.maximum(0, x)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    y = jnp.exp(x)
    return y / y.sum(axis=axis, keepdims=True)

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        self.key1, self.key2 = random.split(key)
        self.linear1 = random.normal(self.key1, (input_size, hidden_size)) / jnp.sqrt(input_size)
        self.linear2 = random.normal(self.key2, (hidden_size, output_size)) / jnp.sqrt(hidden_size)

    def __call__(self, x):
        x = jnp.dot(x, self.linear1)
        x = relu(x)
        x = jnp.dot(x, self.linear2)
        return softmax(x, axis=1)

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(x)
print(list(y.sum(axis=1).flatten()))
[Array(0.99999994, dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(0.9999999, dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Here's the equivalent JAX code for the PyTorch code you provided:
import jax
import jax.numpy as jnp
from jax import random

# Initialize weights and biases
key = random.PRNGKey(0)

def init_layer(key, input_size, output_size):
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (input_size, output_size)) / jnp.sqrt(input_size)
    b = random.normal(b_key, (output_size,))
    return w, b

def mlp_forward(params, inputs):
    w1, b1, w2, b2 = params
    x = jnp.dot(inputs, w1) + b1
    x = jnp.maximum(0, x)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

# Initialize model parameters
input_size, hidden_size, output_size = 10, 20, 30
key1, key2, key3, key4 = random.split(key, 4)
w1, b1 = init_layer(key1, input_size, hidden_size)
w2, b2 = init_layer(key2, hidden_size, output_size)
params = (w1, b1, w2, b2)

# Forward pass
x = random.normal(key3, (5, 10))
y = mlp_forward(params, x)

# Print the sum of each row
print(list(y.sum(axis=1)))
I extracted the following code from that output:
import jax
import jax.numpy as jnp
from jax import random

# Initialize weights and biases
key = random.PRNGKey(0)

def init_layer(key, input_size, output_size):
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (input_size, output_size)) / jnp.sqrt(input_size)
    b = random.normal(b_key, (output_size,))
    return w, b

def mlp_forward(params, inputs):
    w1, b1, w2, b2 = params
    x = jnp.dot(inputs, w1) + b1
    x = jnp.maximum(0, x)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

# Initialize model parameters
input_size, hidden_size, output_size = 10, 20, 30
key1, key2, key3, key4 = random.split(key, 4)
w1, b1 = init_layer(key1, input_size, hidden_size)
w2, b2 = init_layer(key2, hidden_size, output_size)
params = (w1, b1, w2, b2)

# Forward pass
x = random.normal(key3, (5, 10))
y = mlp_forward(params, x)

# Print the sum of each row
print(list(y.sum(axis=1)))
Running the following program:
import jax
import jax.numpy as jnp
from jax import random

# Initialize weights and biases
key = random.PRNGKey(0)

def init_layer(key, input_size, output_size):
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (input_size, output_size)) / jnp.sqrt(input_size)
    b = random.normal(b_key, (output_size,))
    return w, b

def mlp_forward(params, inputs):
    w1, b1, w2, b2 = params
    x = jnp.dot(inputs, w1) + b1
    x = jnp.maximum(0, x)  # ReLU activation
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

# Initialize model parameters
input_size, hidden_size, output_size = 10, 20, 30
key1, key2, key3, key4 = random.split(key, 4)
w1, b1 = init_layer(key1, input_size, hidden_size)
w2, b2 = init_layer(key2, hidden_size, output_size)
params = (w1, b1, w2, b2)

# Forward pass
x = random.normal(key3, (5, 10))
y = mlp_forward(params, x)

# Print the sum of each row
print(list(y.sum(axis=1)))
[Array(1.0000001, dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(0.99999994, dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: True
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
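For comparison, here is one possible pure-JAX translation, offered as a sketch rather than as any of the logged attempts: it keeps bias terms like torch.nn.Linear, applies softmax over axis 1, and converts the result to Python floats before printing, so the output shows plain values close to 1.0. The scaled-normal initialization is an assumption for illustration; torch.nn.Linear uses a different default, which does not affect the softmax row sums.

import jax
import jax.numpy as jnp
from jax import random

def init_linear(key, in_features, out_features):
    # weight and bias for one dense layer; simple scaled-normal init (an assumption)
    w_key, b_key = random.split(key)
    w = random.normal(w_key, (in_features, out_features)) / jnp.sqrt(in_features)
    b = random.normal(b_key, (out_features,)) / jnp.sqrt(in_features)
    return w, b

def mlp(params, x):
    (w1, b1), (w2, b2) = params
    x = jnp.dot(x, w1) + b1
    x = jax.nn.relu(x)
    x = jnp.dot(x, w2) + b2
    return jax.nn.softmax(x, axis=1)

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
params = (init_linear(k1, 10, 20), init_linear(k2, 20, 30))
x = random.normal(k3, (5, 10))
y = mlp(params, x)

# Convert to Python floats so the printed list contains plain numbers, not Array reprs
print([float(v) for v in y.sum(axis=1)])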