Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
### JAX Implementation of the Provided PyTorch Code

Below is the equivalent JAX implementation of the given PyTorch code. Note that JAX uses a functional programming paradigm, so we define the model's forward pass as a plain function that takes the parameters explicitly.
```python
import jax
import jax.numpy as jnp

# Define the MLP model
def mlp_forward(params, x):
    # Unpack the parameters
    linear1_weights, linear1_bias, linear2_weights, linear2_bias = params

    # First linear layer
    x = jnp.dot(x, linear1_weights) + linear1_bias
    x = jax.nn.relu(x)

    # Second linear layer
    x = jnp.dot(x, linear2_weights) + linear2_bias

    # Softmax activation function
    return jax.nn.softmax(x, axis=1)

# Initialize the model parameters
def init_mlp(input_size, hidden_size, output_size, key):
    linear1_weights = jax.random.normal(key, (input_size, hidden_size))
    linear1_bias = jax.random.normal(key, (hidden_size,))
    linear2_weights = jax.random.normal(key, (hidden_size, output_size))
    linear2_bias = jax.random.normal(key, (output_size,))
    return (linear1_weights, linear1_bias, linear2_weights, linear2_bias)

# Main function
def main():
    # Set the random seed
    key = jax.random.PRNGKey(0)

    # Initialize the model
    input_size, hidden_size, output_size = 10, 20, 30
    params = init_mlp(input_size, hidden_size, output_size, key)

    # Generate a random input
    x = jax.random.normal(key, (5, input_size))

    # Forward pass
    y = mlp_forward(params, x)

    # Print the sum of the output along the last axis
    print(jnp.sum(y, axis=1))

if __name__ == "__main__":
    main()
```
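One caveat with the generated code: `init_mlp` reuses the same PRNG key for every parameter tensor and again for the input, so the random draws are correlated rather than independent. That has no effect on the substring checks below, but a more idiomatic initializer would split the key first. A minimal sketch (the name `init_mlp_split` is illustrative, not part of the original output):

```python
import jax

def init_mlp_split(input_size, hidden_size, output_size, key):
    # Split the key so each parameter tensor gets an independent stream
    k1, k2, k3, k4 = jax.random.split(key, 4)
    linear1_weights = jax.random.normal(k1, (input_size, hidden_size))
    linear1_bias = jax.random.normal(k2, (hidden_size,))
    linear2_weights = jax.random.normal(k3, (hidden_size, output_size))
    linear2_bias = jax.random.normal(k4, (output_size,))
    return (linear1_weights, linear1_bias, linear2_weights, linear2_bias)
```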
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp

# Define the MLP model
def mlp_forward(params, x):
    # Unpack the parameters
    linear1_weights, linear1_bias, linear2_weights, linear2_bias = params

    # First linear layer
    x = jnp.dot(x, linear1_weights) + linear1_bias
    x = jax.nn.relu(x)

    # Second linear layer
    x = jnp.dot(x, linear2_weights) + linear2_bias

    # Softmax activation function
    return jax.nn.softmax(x, axis=1)

# Initialize the model parameters
def init_mlp(input_size, hidden_size, output_size, key):
    linear1_weights = jax.random.normal(key, (input_size, hidden_size))
    linear1_bias = jax.random.normal(key, (hidden_size,))
    linear2_weights = jax.random.normal(key, (hidden_size, output_size))
    linear2_bias = jax.random.normal(key, (output_size,))
    return (linear1_weights, linear1_bias, linear2_weights, linear2_bias)

# Main function
def main():
    # Set the random seed
    key = jax.random.PRNGKey(0)

    # Initialize the model
    input_size, hidden_size, output_size = 10, 20, 30
    params = init_mlp(input_size, hidden_size, output_size, key)

    # Generate a random input
    x = jax.random.normal(key, (5, input_size))

    # Forward pass
    y = mlp_forward(params, x)

    # Print the sum of the output along the last axis
    print(jnp.sum(y, axis=1))

if __name__ == "__main__":
    main()
```
Running the following program:
```python
import jax
import jax.numpy as jnp

# Define the MLP model
def mlp_forward(params, x):
    # Unpack the parameters
    linear1_weights, linear1_bias, linear2_weights, linear2_bias = params

    # First linear layer
    x = jnp.dot(x, linear1_weights) + linear1_bias
    x = jax.nn.relu(x)

    # Second linear layer
    x = jnp.dot(x, linear2_weights) + linear2_bias

    # Softmax activation function
    return jax.nn.softmax(x, axis=1)

# Initialize the model parameters
def init_mlp(input_size, hidden_size, output_size, key):
    linear1_weights = jax.random.normal(key, (input_size, hidden_size))
    linear1_bias = jax.random.normal(key, (hidden_size,))
    linear2_weights = jax.random.normal(key, (hidden_size, output_size))
    linear2_bias = jax.random.normal(key, (output_size,))
    return (linear1_weights, linear1_bias, linear2_weights, linear2_bias)

# Main function
def main():
    # Set the random seed
    key = jax.random.PRNGKey(0)

    # Initialize the model
    input_size, hidden_size, output_size = 10, 20, 30
    params = init_mlp(input_size, hidden_size, output_size, key)

    # Generate a random input
    x = jax.random.normal(key, (5, input_size))

    # Forward pass
    y = mlp_forward(params, x)

    # Print the sum of the output along the last axis
    print(jnp.sum(y, axis=1))

if __name__ == "__main__":
    main()
```
```
[1.         0.99999994 1.         1.0000001  1.        ]
```
Testing if the previous output contains the string `0.99999`: True

Testing if the previous output contains the string `1.0`: False

Testing if the previous output contains the string `1.00000`: True

Testing if the previous output contains the string `1.0,`: False

Testing if the previous output contains the string `import torch`: False
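For context, the row sums printed above are only equal to 1 up to float32 rounding, since softmax normalizes each row in single precision; that is why values like 0.99999994 and 1.0000001 appear and why the `0.99999` check passes. A quick sanity check on the printed values, as a sketch:

```python
import jax.numpy as jnp

# Row sums taken from the program output above
row_sums = jnp.array([1.0, 0.99999994, 1.0, 1.0000001, 1.0])

# Each softmax row should sum to 1 within float32 rounding error
print(bool(jnp.allclose(row_sums, 1.0)))  # True
```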