Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent implementation in pure JAX:
```python
import jax
import jax.numpy as jnp
import numpy as np

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        key1, key2 = jax.random.split(key)

        # Initialize weights using Glorot uniform initialization
        self.linear1_weight = jax.random.uniform(
            key1,
            shape=(hidden_size, input_size),
            minval=-jnp.sqrt(6 / (input_size + hidden_size)),
            maxval=jnp.sqrt(6 / (input_size + hidden_size))
        )
        self.linear1_bias = jnp.zeros(hidden_size)

        self.linear2_weight = jax.random.uniform(
            key2,
            shape=(output_size, hidden_size),
            minval=-jnp.sqrt(6 / (hidden_size + output_size)),
            maxval=jnp.sqrt(6 / (hidden_size + output_size))
        )
        self.linear2_bias = jnp.zeros(output_size)

    def __call__(self, x):
        # First linear layer
        x = jnp.dot(x, self.linear1_weight.T) + self.linear1_bias
        # ReLU activation
        x = jnp.maximum(x, 0)
        # Second linear layer
        x = jnp.dot(x, self.linear2_weight.T) + self.linear2_bias
        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set a random key for reproducibility
key = jax.random.PRNGKey(0)

# Create the model
model = MLP(10, 20, 30, key)

# Generate random input
x = jax.random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp
import numpy as np

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        key1, key2 = jax.random.split(key)

        # Initialize weights using Glorot uniform initialization
        self.linear1_weight = jax.random.uniform(
            key1,
            shape=(hidden_size, input_size),
            minval=-jnp.sqrt(6 / (input_size + hidden_size)),
            maxval=jnp.sqrt(6 / (input_size + hidden_size))
        )
        self.linear1_bias = jnp.zeros(hidden_size)

        self.linear2_weight = jax.random.uniform(
            key2,
            shape=(output_size, hidden_size),
            minval=-jnp.sqrt(6 / (hidden_size + output_size)),
            maxval=jnp.sqrt(6 / (hidden_size + output_size))
        )
        self.linear2_bias = jnp.zeros(output_size)

    def __call__(self, x):
        # First linear layer
        x = jnp.dot(x, self.linear1_weight.T) + self.linear1_bias
        # ReLU activation
        x = jnp.maximum(x, 0)
        # Second linear layer
        x = jnp.dot(x, self.linear2_weight.T) + self.linear2_bias
        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set a random key for reproducibility
key = jax.random.PRNGKey(0)

# Create the model
model = MLP(10, 20, 30, key)

# Generate random input
x = jax.random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
Running the following program:
```python
import jax
import jax.numpy as jnp
import numpy as np

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        key1, key2 = jax.random.split(key)

        # Initialize weights using Glorot uniform initialization
        self.linear1_weight = jax.random.uniform(
            key1,
            shape=(hidden_size, input_size),
            minval=-jnp.sqrt(6 / (input_size + hidden_size)),
            maxval=jnp.sqrt(6 / (input_size + hidden_size))
        )
        self.linear1_bias = jnp.zeros(hidden_size)

        self.linear2_weight = jax.random.uniform(
            key2,
            shape=(output_size, hidden_size),
            minval=-jnp.sqrt(6 / (hidden_size + output_size)),
            maxval=jnp.sqrt(6 / (hidden_size + output_size))
        )
        self.linear2_bias = jnp.zeros(output_size)

    def __call__(self, x):
        # First linear layer
        x = jnp.dot(x, self.linear1_weight.T) + self.linear1_bias
        # ReLU activation
        x = jnp.maximum(x, 0)
        # Second linear layer
        x = jnp.dot(x, self.linear2_weight.T) + self.linear2_bias
        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set a random key for reproducibility
key = jax.random.PRNGKey(0)

# Create the model
model = MLP(10, 20, 30, key)

# Generate random input
x = jax.random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
[Array(1.0000001, dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(0.9999999, dtype=float32), Array(1.0000001, dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: True
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
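A side note on initialization: this first answer uses Glorot uniform, whereas PyTorch's `torch.nn.Linear` defaults to uniform samples in ±1/sqrt(fan_in) for both the weight and the bias. That difference has no effect on whether the softmax rows sum to 1, but a JAX port that mirrors PyTorch's default more closely could initialize each layer as in the sketch below (the `init_linear` helper is illustrative, not taken from the output above).

```python
import jax
import jax.numpy as jnp

def init_linear(key, fan_in, fan_out):
    # PyTorch's nn.Linear default: weight and bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))
    wkey, bkey = jax.random.split(key)
    bound = 1.0 / jnp.sqrt(fan_in)
    w = jax.random.uniform(wkey, (fan_out, fan_in), minval=-bound, maxval=bound)
    b = jax.random.uniform(bkey, (fan_out,), minval=-bound, maxval=bound)
    return w, b

key = jax.random.PRNGKey(0)
w1, b1 = init_linear(key, 10, 20)
print(w1.shape, b1.shape)  # (20, 10) (20,)
```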
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent implementation in pure JAX:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights and biases using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x)

# Set random seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute forward pass
y = model(x)

# Print sum along last dimension
print(list(y.sum(axis=1)))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights and biases using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x)

# Set random seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute forward pass
y = model(x)

# Print sum along last dimension
print(list(y.sum(axis=1)))
```
Running the following program:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights and biases using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x)

# Set random seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute forward pass
y = model(x)

# Print sum along last dimension
print(list(y.sum(axis=1)))
```
[Array(1., dtype=float32), Array(1., dtype=float32), Array(0.9999999, dtype=float32), Array(0.99999994, dtype=float32), Array(0.9999999, dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
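This second answer calls `jax.nn.softmax(x)` without an axis argument. That still matches `softmax(x, dim=1)` here because JAX's default is `axis=-1`, which is the same axis as 1 for a 2-D batch of shape (5, 30). A quick standalone check:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(1), (5, 30))
# For a 2-D array, the default axis=-1 and axis=1 refer to the same axis.
print(jnp.allclose(jax.nn.softmax(x), jax.nn.softmax(x, axis=1)))  # True
```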
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent JAX implementation:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Glorot (Xavier) initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        stddev1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * stddev1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        stddev2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * stddev2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Create random key for initialization
key = random.PRNGKey(0)

# Create model
model = MLP(10, 20, 30, key)

# Create random input
x = random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Glorot (Xavier) initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        stddev1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * stddev1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        stddev2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * stddev2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Create random key for initialization
key = random.PRNGKey(0)

# Create model
model = MLP(10, 20, 30, key)

# Create random input
x = random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
Running the following program:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Glorot (Xavier) initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        stddev1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * stddev1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        stddev2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * stddev2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Create random key for initialization
key = random.PRNGKey(0)

# Create model
model = MLP(10, 20, 30, key)

# Create random input
x = random.normal(key, (5, 10))

# Forward pass
y = model(x)

# Print sum along rows
print(list(y.sum(1)))
```
[Array(1., dtype=float32), Array(1., dtype=float32), Array(0.9999999, dtype=float32), Array(0.99999994, dtype=float32), Array(0.9999999, dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
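One pattern shared by these answers: the same `key` is reused both to build the model (where it is split into `key1`/`key2`) and to draw `x` directly. Nothing goes wrong for a single forward pass, but the usual JAX convention is to split a key once per independent random draw, roughly as in this sketch (names are illustrative):

```python
import jax

key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)

# Parameters and inputs now come from independent streams:
# model = MLP(10, 20, 30, model_key)
x = jax.random.normal(data_key, (5, 10))
```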
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent implementation in pure JAX:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set random seed for reproducibility
key = random.PRNGKey(42)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)

# Print sum along last dimension
print(list(y.sum(1)))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set random seed for reproducibility
key = random.PRNGKey(42)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)

# Print sum along last dimension
print(list(y.sum(1)))
```
Running the following program:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Initialize weights using Xavier initialization
        key1, key2 = random.split(key)

        # First layer weights and biases
        std1 = jnp.sqrt(2.0 / (input_size + hidden_size))
        self.w1 = random.normal(key1, (input_size, hidden_size)) * std1
        self.b1 = jnp.zeros((hidden_size,))

        # Second layer weights and biases
        std2 = jnp.sqrt(2.0 / (hidden_size + output_size))
        self.w2 = random.normal(key2, (hidden_size, output_size)) * std2
        self.b2 = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer
        x = jnp.dot(x, self.w1) + self.b1
        x = jax.nn.relu(x)

        # Second layer
        x = jnp.dot(x, self.w2) + self.b2

        # Softmax
        return jax.nn.softmax(x, axis=1)

# Set random seed for reproducibility
key = random.PRNGKey(42)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)

# Print sum along last dimension
print(list(y.sum(1)))
```
[Array(1., dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
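The row sums print as values like 0.99999994 or 1.0000001 because the softmax is computed in float32. If the goal is just to confirm that every row is a probability distribution up to rounding, a tolerance-based comparison does that directly; a small sketch using the values printed above:

```python
import jax.numpy as jnp

row_sums = jnp.array([1.0, 1.0, 1.0, 0.99999994, 1.0])
print(jnp.allclose(row_sums, 1.0, atol=1e-5))  # True
```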
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
Here's the equivalent implementation in JAX:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Use random.PRNGKey for random initialization
        key1, key2 = random.split(key)

        # Initialize weights with random values
        self.linear1_w = random.normal(key1, (input_size, hidden_size)) * 0.01
        self.linear1_b = jnp.zeros((hidden_size,))
        self.linear2_w = random.normal(key2, (hidden_size, output_size)) * 0.01
        self.linear2_b = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer: linear + ReLU
        x = jnp.dot(x, self.linear1_w) + self.linear1_b
        x = jax.nn.relu(x)

        # Second layer: linear + softmax
        x = jnp.dot(x, self.linear2_w) + self.linear2_b
        return jax.nn.softmax(x, axis=1)

# Set a seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)
print(list(y.sum(1)))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Use random.PRNGKey for random initialization
        key1, key2 = random.split(key)

        # Initialize weights with random values
        self.linear1_w = random.normal(key1, (input_size, hidden_size)) * 0.01
        self.linear1_b = jnp.zeros((hidden_size,))
        self.linear2_w = random.normal(key2, (hidden_size, output_size)) * 0.01
        self.linear2_b = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer: linear + ReLU
        x = jnp.dot(x, self.linear1_w) + self.linear1_b
        x = jax.nn.relu(x)

        # Second layer: linear + softmax
        x = jnp.dot(x, self.linear2_w) + self.linear2_b
        return jax.nn.softmax(x, axis=1)

# Set a seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)
print(list(y.sum(1)))
```
Running the following program:
```python
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        # Use random.PRNGKey for random initialization
        key1, key2 = random.split(key)

        # Initialize weights with random values
        self.linear1_w = random.normal(key1, (input_size, hidden_size)) * 0.01
        self.linear1_b = jnp.zeros((hidden_size,))
        self.linear2_w = random.normal(key2, (hidden_size, output_size)) * 0.01
        self.linear2_b = jnp.zeros((output_size,))

    def __call__(self, x):
        # First layer: linear + ReLU
        x = jnp.dot(x, self.linear1_w) + self.linear1_b
        x = jax.nn.relu(x)

        # Second layer: linear + softmax
        x = jnp.dot(x, self.linear2_w) + self.linear2_b
        return jax.nn.softmax(x, axis=1)

# Set a seed for reproducibility
key = random.PRNGKey(0)

# Create model and input
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))

# Compute output
y = model(x)
print(list(y.sum(1)))
```
[Array(0.9999998, dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(1.0000001, dtype=float32), Array(1.0000001, dtype=float32)]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: True
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
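All five answers keep the PyTorch-style class and store arrays on `self`. A more conventional pure-JAX arrangement holds the parameters in a pytree and passes them to a jitted function; a sketch of that style with the same layer sizes (function and dictionary names are mine, not from the transcript):

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(10, 20, 30)):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (sizes[0], sizes[1])) * jnp.sqrt(2.0 / (sizes[0] + sizes[1])),
        "b1": jnp.zeros(sizes[1]),
        "w2": jax.random.normal(k2, (sizes[1], sizes[2])) * jnp.sqrt(2.0 / (sizes[1] + sizes[2])),
        "b2": jnp.zeros(sizes[2]),
    }

@jax.jit
def forward(params, x):
    # Linear -> ReLU -> Linear -> softmax over the class axis
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return jax.nn.softmax(h @ params["w2"] + params["b2"], axis=1)

key = jax.random.PRNGKey(0)
param_key, data_key = jax.random.split(key)
params = init_params(param_key)
x = jax.random.normal(data_key, (5, 10))
print(list(forward(params, x).sum(1)))
```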