Convert the following PyTorch code to pure JAX. Do not use any other libraries.
```python
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
```
```python
import jax
import jax.numpy as jnp
import jax.random as random
from jax.nn import initializers

# Define the MLP structure and forward pass using functions
def init_mlp_params(input_size, hidden_size, output_size, key):
    """Initializes parameters for a two-layer MLP."""
    key1, key2 = random.split(key)

    # Layer 1 parameters (Weight and Bias)
    # Use Glorot/Xavier initialization, common for linear layers
    # JAX convention is often (in_features, out_features) for weights
    w1_key, b1_key = random.split(key1)
    W1 = initializers.glorot_normal()(w1_key, (input_size, hidden_size))
    b1 = initializers.zeros(b1_key, (hidden_size,))  # Initialize biases to zero

    # Layer 2 parameters (Weight and Bias)
    w2_key, b2_key = random.split(key2)
    W2 = initializers.glorot_normal()(w2_key, (hidden_size, output_size))
    b2 = initializers.zeros(b2_key, (output_size,))

    # Store parameters in a structure (e.g., a list of tuples or a dictionary)
    # Using a list of tuples (W, b) for each layer
    params = [(W1, b1), (W2, b2)]
    return params

# Define the forward pass function
# Use jax.jit for potential performance improvement through compilation
@jax.jit
def mlp_predict(params, x):
    """MLP forward pass."""
    # Unpack parameters
    (W1, b1), (W2, b2) = params

    # First linear layer + ReLU activation
    x = jnp.dot(x, W1) + b1
    x = jax.nn.relu(x)

    # Second linear layer
    x = jnp.dot(x, W2) + b2

    # Softmax activation
    # axis=1 corresponds to PyTorch's dim=1 for shape (batch, features)
    return jax.nn.softmax(x, axis=1)

# --- Execution Example ---

# Define model dimensions
input_size = 10
hidden_size = 20
output_size = 30
batch_size = 5

# Create a random key for initialization and data generation
key = random.PRNGKey(0)  # Use a fixed seed for reproducibility
init_key, data_key = random.split(key)

# Initialize model parameters
params = init_mlp_params(input_size, hidden_size, output_size, init_key)

# Generate random input data (equivalent to torch.randn)
x = random.normal(data_key, (batch_size, input_size))

# Perform the forward pass
y = mlp_predict(params, x)

# Print the sum of probabilities for each item in the batch
# Note: The original PyTorch print statement `list(y.flatten().numpy().sum(1))`
# would likely cause an error because after flattening, the array is 1D and
# sum(1) is invalid.
# A more meaningful operation after softmax is often summing probabilities
# per batch item (axis=1), which should be close to 1.0.
# We replicate this meaningful operation here using JAX arrays and .tolist().
print(y.sum(axis=1).tolist())

# If you strictly wanted the sum of *all* elements in the flattened output:
# print(y.flatten().sum().tolist())
```
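One detail worth noting if the weights from the original PyTorch module ever need to be reused: `torch.nn.Linear` stores its weight with shape `(out_features, in_features)`, while the JAX code above multiplies by an `(in_features, out_features)` matrix, so a faithful port would transpose each weight. A minimal sketch, assuming the PyTorch weights have already been exported as plain arrays (the `torch_w`/`torch_b` names here are hypothetical, not part of the program above):

```python
import jax.numpy as jnp

def port_linear(torch_w, torch_b):
    """Convert a torch.nn.Linear weight/bias pair (given as plain arrays)
    to the (in_features, out_features) layout used by mlp_predict above."""
    W = jnp.asarray(torch_w).T  # (out, in) -> (in, out)
    b = jnp.asarray(torch_b)
    return W, b

# Hypothetical usage, assuming w1, b1, w2, b2 were exported from the PyTorch
# model (e.g. via .detach().numpy()) before being handed to this pure-JAX code:
# params = [port_linear(w1, b1), port_linear(w2, b2)]
```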
I extracted the code above from that output.
Running that program produced the following output:
[0.9999998807907104, 1.0, 0.9999999403953552, 1.0, 1.0]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: True
Testing if the previous output contains the string import torch: False
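The row sums in the output above hover around 1.0, but a couple print as 0.99999..., which is ordinary float32 rounding in the softmax rather than a flaw in the conversion, and is presumably what the 0.99999 substring check is probing. A minimal sketch of verifying the property numerically instead of by substring, using the same shapes as the example above (the logits here are freshly sampled for illustration):

```python
import jax
import jax.numpy as jnp

# Recompute a softmax over a (batch, classes) array of the same shape and
# check that every row sums to 1 up to float32 tolerance.
logits = jax.random.normal(jax.random.PRNGKey(0), (5, 30))
probs = jax.nn.softmax(logits, axis=1)
print(jnp.allclose(probs.sum(axis=1), 1.0, atol=1e-6))  # expected: True
```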