Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Below is one way to rewrite the code in pure JAX without using any extra libraries:

------------------------------------------------------------
#!/usr/bin/env python
# Pure JAX implementation of the MLP example

import jax
import jax.numpy as jnp
from jax import random

# Initialize parameters for a fully-connected (linear) layer.
def init_linear_params(key, in_dim, out_dim):
    # Use a simple Xavier/He-like initialization.
    k1, _ = random.split(key)
    # Scale by sqrt(2/in_dim) (you can adjust the scaling as needed).
    weight = random.normal(k1, (in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
    bias = jnp.zeros((out_dim,))
    return weight, bias

# Initialize parameters for the two-layer MLP.
def init_mlp_params(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    W1, b1 = init_linear_params(key1, input_size, hidden_size)
    W2, b2 = init_linear_params(key2, hidden_size, output_size)
    return {"W1": W1, "b1": b1, "W2": W2, "b2": b2}

# Define the MLP forward pass.
def mlp_forward(params, x):
    # First linear layer: x @ W1 + b1
    x = jnp.dot(x, params["W1"]) + params["b1"]
    # Apply ReLU nonlinearity.
    x = jnp.maximum(x, 0)
    # Second linear layer: x @ W2 + b2
    x = jnp.dot(x, params["W2"]) + params["b2"]
    # Compute the softmax along axis 1.
    # For numerical stability, subtract the maximum value in each row.
    x_max = jnp.max(x, axis=1, keepdims=True)
    exps = jnp.exp(x - x_max)
    softmax = exps / jnp.sum(exps, axis=1, keepdims=True)
    return softmax

# Main execution:
def main():
    # Set up a random key.
    rng = random.PRNGKey(0)
    # Initialize MLP parameters for an MLP with
    # input size of 10, hidden size of 20, and output size of 30.
    params = init_mlp_params(rng, input_size=10, hidden_size=20, output_size=30)
    # Generate a random input tensor like torch.randn(5, 10)
    x = random.normal(rng, (5, 10))
    # Compute the model output.
    y = mlp_forward(params, x)
    # Print the sum over each row. (The original code printed the sum for each row.)
    # Note: In the PyTorch code there was a call to y.flatten().numpy().sum(1)
    # which is a bit unusual; here we simply sum each row of y.
    print(list(y.sum(axis=1)))

if __name__ == "__main__":
    main()
------------------------------------------------------------

Explanation:

1. We import jax and jax.numpy as jnp.
2. The function init_linear_params builds each fully connected layer by drawing random normal weights and zeros for biases.
3. The function init_mlp_params splits the random key to initialize the two layers.
4. The mlp_forward function computes the two linear transformations with a ReLU activation after the first layer and applies a softmax (computed manually for numerical stability) on the final output.
5. In the main function we generate a random input with shape (5, 10), run the model, and then print a list of row-sums of the softmax output.

This pure JAX code does not depend on any extra libraries aside from JAX itself.
I extracted the following code from that output:
#!/usr/bin/env python
# Pure JAX implementation of the MLP example

import jax
import jax.numpy as jnp
from jax import random

# Initialize parameters for a fully-connected (linear) layer.
def init_linear_params(key, in_dim, out_dim):
    # Use a simple Xavier/He-like initialization.
    k1, _ = random.split(key)
    # Scale by sqrt(2/in_dim) (you can adjust the scaling as needed).
    weight = random.normal(k1, (in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
    bias = jnp.zeros((out_dim,))
    return weight, bias

# Initialize parameters for the two-layer MLP.
def init_mlp_params(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    W1, b1 = init_linear_params(key1, input_size, hidden_size)
    W2, b2 = init_linear_params(key2, hidden_size, output_size)
    return {"W1": W1, "b1": b1, "W2": W2, "b2": b2}

# Define the MLP forward pass.
def mlp_forward(params, x):
    # First linear layer: x @ W1 + b1
    x = jnp.dot(x, params["W1"]) + params["b1"]
    # Apply ReLU nonlinearity.
    x = jnp.maximum(x, 0)
    # Second linear layer: x @ W2 + b2
    x = jnp.dot(x, params["W2"]) + params["b2"]
    # Compute the softmax along axis 1.
    # For numerical stability, subtract the maximum value in each row.
    x_max = jnp.max(x, axis=1, keepdims=True)
    exps = jnp.exp(x - x_max)
    softmax = exps / jnp.sum(exps, axis=1, keepdims=True)
    return softmax

# Main execution:
def main():
    # Set up a random key.
    rng = random.PRNGKey(0)
    # Initialize MLP parameters for an MLP with
    # input size of 10, hidden size of 20, and output size of 30.
    params = init_mlp_params(rng, input_size=10, hidden_size=20, output_size=30)
    # Generate a random input tensor like torch.randn(5, 10)
    x = random.normal(rng, (5, 10))
    # Compute the model output.
    y = mlp_forward(params, x)
    # Print the sum over each row. (The original code printed the sum for each row.)
    # Note: In the PyTorch code there was a call to y.flatten().numpy().sum(1)
    # which is a bit unusual; here we simply sum each row of y.
    print(list(y.sum(axis=1)))

if __name__ == "__main__":
    main()
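As a side note, the hand-rolled, max-subtracted softmax in this code is the same computation that jax.nn.softmax performs. Below is a minimal standalone sketch (with arbitrary example logits, not taken from the program above) that checks the two agree within float32 tolerance:

import jax
import jax.numpy as jnp
from jax import random

def stable_softmax(x):
    # Subtract the per-row maximum before exponentiating, as mlp_forward does.
    x_max = jnp.max(x, axis=1, keepdims=True)
    exps = jnp.exp(x - x_max)
    return exps / jnp.sum(exps, axis=1, keepdims=True)

logits = random.normal(random.PRNGKey(42), (5, 30))
print(bool(jnp.allclose(stable_softmax(logits), jax.nn.softmax(logits, axis=1), atol=1e-6)))
# Expected to print True.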
Running the following program:
#!/usr/bin/env python
# Pure JAX implementation of the MLP example

import jax
import jax.numpy as jnp
from jax import random

# Initialize parameters for a fully-connected (linear) layer.
def init_linear_params(key, in_dim, out_dim):
    # Use a simple Xavier/He-like initialization.
    k1, _ = random.split(key)
    # Scale by sqrt(2/in_dim) (you can adjust the scaling as needed).
    weight = random.normal(k1, (in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
    bias = jnp.zeros((out_dim,))
    return weight, bias

# Initialize parameters for the two-layer MLP.
def init_mlp_params(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    W1, b1 = init_linear_params(key1, input_size, hidden_size)
    W2, b2 = init_linear_params(key2, hidden_size, output_size)
    return {"W1": W1, "b1": b1, "W2": W2, "b2": b2}

# Define the MLP forward pass.
def mlp_forward(params, x):
    # First linear layer: x @ W1 + b1
    x = jnp.dot(x, params["W1"]) + params["b1"]
    # Apply ReLU nonlinearity.
    x = jnp.maximum(x, 0)
    # Second linear layer: x @ W2 + b2
    x = jnp.dot(x, params["W2"]) + params["b2"]
    # Compute the softmax along axis 1.
    # For numerical stability, subtract the maximum value in each row.
    x_max = jnp.max(x, axis=1, keepdims=True)
    exps = jnp.exp(x - x_max)
    softmax = exps / jnp.sum(exps, axis=1, keepdims=True)
    return softmax

# Main execution:
def main():
    # Set up a random key.
    rng = random.PRNGKey(0)
    # Initialize MLP parameters for an MLP with
    # input size of 10, hidden size of 20, and output size of 30.
    params = init_mlp_params(rng, input_size=10, hidden_size=20, output_size=30)
    # Generate a random input tensor like torch.randn(5, 10)
    x = random.normal(rng, (5, 10))
    # Compute the model output.
    y = mlp_forward(params, x)
    # Print the sum over each row. (The original code printed the sum for each row.)
    # Note: In the PyTorch code there was a call to y.flatten().numpy().sum(1)
    # which is a bit unusual; here we simply sum each row of y.
    print(list(y.sum(axis=1)))

if __name__ == "__main__":
    main()
[Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string "0.99999": True
Testing if the previous output contains the string "1.0": False
Testing if the previous output contains the string "1.00000": False
Testing if the previous output contains the string "1.0,": False
Testing if the previous output contains the string "import torch": False
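The "1.0" checks fail only because the printed list holds JAX Array objects, whose repr is Array(1., dtype=float32). A minimal sketch of one way to print plain Python floats instead (the row_sums name and its values are stand-ins, not taken from the run above):

import jax.numpy as jnp

row_sums = jnp.array([1.0, 0.99999994, 1.0, 1.0, 1.0])  # stand-in for y.sum(axis=1)
# Array.tolist() converts the device array to ordinary Python floats, so the
# printed text contains "1.0" rather than "Array(1., dtype=float32)".
print(row_sums.tolist())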
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
Below is one way to write the same model and computation in pure JAX (using only JAX, no other external libraries):

------------------------------------------------------------
#!/usr/bin/env python3
import jax
import jax.numpy as jnp

# Initialize parameters for a two-layer fully connected network.
def init_mlp_params(key, input_size, hidden_size, output_size):
    # Split the PRNG key for reproducible independent initializations.
    key1, key2 = jax.random.split(key, 2)
    # For the first linear layer, we create weights of shape (input_size, hidden_size)
    # and biases of shape (hidden_size,).
    w1 = jax.random.normal(key1, (input_size, hidden_size))
    b1 = jnp.zeros((hidden_size,))
    # For the second linear layer, we create weights of shape (hidden_size, output_size)
    # and biases of shape (output_size,).
    w2 = jax.random.normal(key2, (hidden_size, output_size))
    b2 = jnp.zeros((output_size,))
    return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

# Define the forward pass of the MLP: Linear -> ReLU -> Linear -> Softmax.
def mlp_apply(params, x):
    # First linear layer: x @ W + b
    x = jnp.dot(x, params['w1']) + params['b1']
    # Apply ReLU activation.
    x = jnp.maximum(x, 0)
    # Second linear layer.
    x = jnp.dot(x, params['w2']) + params['b2']
    # Compute softmax over the last axis (axis=1) to obtain probabilities.
    x = jax.nn.softmax(x, axis=1)
    return x

def main():
    # Create an initial PRNG key.
    key = jax.random.PRNGKey(0)
    # Initialize the MLP parameters with input size=10, hidden size=20, output size=30.
    params = init_mlp_params(key, 10, 20, 30)
    # Generate a random input tensor, equivalent to torch.randn(5, 10).
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (5, 10))
    # Compute the model's output.
    y = mlp_apply(params, x)
    # In PyTorch the code printed:
    #     list(y.flatten().numpy().sum(1))
    # Since softmax makes each row sum to 1, we compute the row-sums directly.
    row_sums = jnp.sum(y, axis=1)
    # Convert to a Python list and print.
    print(list(row_sums))

if __name__ == '__main__':
    main()
------------------------------------------------------------

Explanation:

1. The function init_mlp_params splits a PRNG key into two parts and randomly initializes the weights for each linear layer while setting the biases to 0.
2. The mlp_apply function computes the forward pass:
   • It performs a dot product (i.e. a linear layer) using jnp.dot.
   • The ReLU is implemented via jnp.maximum.
   • The softmax is computed with jax.nn.softmax along axis 1.
3. In the main() function, a random 5×10 input is created with jax.random.normal, the model output is calculated, and the row-sums (which should be 1.0 for each sample) are printed.

This code is written entirely using JAX and does not rely on any additional libraries.
I extracted the following code from that output:
#!/usr/bin/env python3
import jax
import jax.numpy as jnp

# Initialize parameters for a two-layer fully connected network.
def init_mlp_params(key, input_size, hidden_size, output_size):
    # Split the PRNG key for reproducible independent initializations.
    key1, key2 = jax.random.split(key, 2)
    # For the first linear layer, we create weights of shape (input_size, hidden_size)
    # and biases of shape (hidden_size,).
    w1 = jax.random.normal(key1, (input_size, hidden_size))
    b1 = jnp.zeros((hidden_size,))
    # For the second linear layer, we create weights of shape (hidden_size, output_size)
    # and biases of shape (output_size,).
    w2 = jax.random.normal(key2, (hidden_size, output_size))
    b2 = jnp.zeros((output_size,))
    return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

# Define the forward pass of the MLP: Linear -> ReLU -> Linear -> Softmax.
def mlp_apply(params, x):
    # First linear layer: x @ W + b
    x = jnp.dot(x, params['w1']) + params['b1']
    # Apply ReLU activation.
    x = jnp.maximum(x, 0)
    # Second linear layer.
    x = jnp.dot(x, params['w2']) + params['b2']
    # Compute softmax over the last axis (axis=1) to obtain probabilities.
    x = jax.nn.softmax(x, axis=1)
    return x

def main():
    # Create an initial PRNG key.
    key = jax.random.PRNGKey(0)
    # Initialize the MLP parameters with input size=10, hidden size=20, output size=30.
    params = init_mlp_params(key, 10, 20, 30)
    # Generate a random input tensor, equivalent to torch.randn(5, 10).
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (5, 10))
    # Compute the model's output.
    y = mlp_apply(params, x)
    # In PyTorch the code printed:
    #     list(y.flatten().numpy().sum(1))
    # Since softmax makes each row sum to 1, we compute the row-sums directly.
    row_sums = jnp.sum(y, axis=1)
    # Convert to a Python list and print.
    print(list(row_sums))

if __name__ == '__main__':
    main()
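One detail worth noting in this version: main() advances the key with jax.random.split before drawing x, whereas the first program passed the same rng key to both the parameter initialization and the input draw. A minimal sketch of the difference between reusing and splitting a PRNG key (shapes chosen arbitrarily):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces the same draw.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(bool(jnp.array_equal(a, b)))  # True

# Splitting yields fresh keys whose draws differ.
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))
print(bool(jnp.array_equal(c, d)))  # False (with overwhelming probability)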
Running the following program:
#!/usr/bin/env python3
import jax
import jax.numpy as jnp

# Initialize parameters for a two-layer fully connected network.
def init_mlp_params(key, input_size, hidden_size, output_size):
    # Split the PRNG key for reproducible independent initializations.
    key1, key2 = jax.random.split(key, 2)
    # For the first linear layer, we create weights of shape (input_size, hidden_size)
    # and biases of shape (hidden_size,).
    w1 = jax.random.normal(key1, (input_size, hidden_size))
    b1 = jnp.zeros((hidden_size,))
    # For the second linear layer, we create weights of shape (hidden_size, output_size)
    # and biases of shape (output_size,).
    w2 = jax.random.normal(key2, (hidden_size, output_size))
    b2 = jnp.zeros((output_size,))
    return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

# Define the forward pass of the MLP: Linear -> ReLU -> Linear -> Softmax.
def mlp_apply(params, x):
    # First linear layer: x @ W + b
    x = jnp.dot(x, params['w1']) + params['b1']
    # Apply ReLU activation.
    x = jnp.maximum(x, 0)
    # Second linear layer.
    x = jnp.dot(x, params['w2']) + params['b2']
    # Compute softmax over the last axis (axis=1) to obtain probabilities.
    x = jax.nn.softmax(x, axis=1)
    return x

def main():
    # Create an initial PRNG key.
    key = jax.random.PRNGKey(0)
    # Initialize the MLP parameters with input size=10, hidden size=20, output size=30.
    params = init_mlp_params(key, 10, 20, 30)
    # Generate a random input tensor, equivalent to torch.randn(5, 10).
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (5, 10))
    # Compute the model's output.
    y = mlp_apply(params, x)
    # In PyTorch the code printed:
    #     list(y.flatten().numpy().sum(1))
    # Since softmax makes each row sum to 1, we compute the row-sums directly.
    row_sums = jnp.sum(y, axis=1)
    # Convert to a Python list and print.
    print(list(row_sums))

if __name__ == '__main__':
    main()
[Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(1., dtype=float32), Array(0.99999994, dtype=float32), Array(1., dtype=float32)]
Testing if the previous output contains the string "0.99999": True
Testing if the previous output contains the string "1.0": False
Testing if the previous output contains the string "1.00000": False
Testing if the previous output contains the string "1.0,": False
Testing if the previous output contains the string "import torch": False
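For reference, the stray 0.99999994 values in both runs are ordinary float32 rounding in the softmax normalization rather than a bug in either conversion. A minimal sketch comparing float32 and float64 row sums (enabling 64-bit mode here is an illustration choice, not part of the original task):

import jax
jax.config.update("jax_enable_x64", True)  # only so the float64 comparison below is meaningful

import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(1), (5, 30), dtype=jnp.float32)
row_sums_32 = jax.nn.softmax(logits, axis=1).sum(axis=1)
row_sums_64 = jax.nn.softmax(logits.astype(jnp.float64), axis=1).sum(axis=1)

# The float32 sums may come out as values like 0.99999994 or 1.0000001;
# the float64 sums sit much closer to exactly 1.
print(row_sums_32)
print(row_sums_64)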