Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

def init_layer(key, m, n):
    key, subkey = random.split(key)
    w_key, b_key = random.split(subkey)
    w = random.normal(w_key, (n, m)) * (2 / m)**0.5  # Kaiming initialization
    b = random.normal(b_key, (n,))
    return key, (w, b)

class MLP:
    def __init__(self, key, input_size, hidden_size, output_size):
        key, layer1 = init_layer(key, input_size, hidden_size)
        key, layer2 = init_layer(key, hidden_size, output_size)
        self.layer1 = layer1
        self.layer2 = layer2

    def __call__(self, x):
        w1, b1 = self.layer1
        x = jnp.dot(x, w1.T) + b1
        x = jax.nn.relu(x)
        w2, b2 = self.layer2
        x = jnp.dot(x, w2.T) + b2
        return jax.nn.softmax(x, axis=1)

input_size = 10
hidden_size = 20
output_size = 30

model = MLP(key, input_size, hidden_size, output_size)
x = random.normal(key, (5, input_size))
y = model(x)

print(list(y.flatten().sum(1)))  # No need to convert to numpy in JAX
I extracted the code from that output and ran it; the extracted program is identical to the listing above. Running it produced:
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 39, in <module>
    print(list(y.flatten().sum(1)))  # No need to convert to numpy in JAX
               ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/array_methods.py", line 363, in _sum
    return reductions.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 289, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 213, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 111, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
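For reference, the failure above comes from calling .sum(1) on the result of flatten(): flattening produces a 1-D array, so axis 1 no longer exists (the original PyTorch print line has the same problem). A minimal standalone sketch of the failing pattern and a per-row sum that does work, using an illustrative 5x30 softmax output rather than the benchmark's model:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
y = jax.nn.softmax(jax.random.normal(key, (5, 30)), axis=1)  # stand-in for the model output

# y.flatten() has shape (150,), so .sum(1) raises "axis 1 is out of bounds for array of dimension 1".
# Summing over the class axis of the 2-D array gives one value per row instead:
row_sums = jnp.sum(y, axis=1)
print([float(v) for v in row_sums])  # each entry is approximately 1.0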
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        key1, key2 = random.split(key)
        self.linear1 = Linear(input_size, hidden_size, key1)
        self.linear2 = Linear(hidden_size, output_size, key2)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return jax.nn.softmax(x, axis=1)

class Linear:
    def __init__(self, input_size, output_size, key):
        w_key, b_key = random.split(key)
        self.w = random.normal(w_key, (input_size, output_size)) * jax.lax.rsqrt(input_size)  # Xavier initialization
        self.b = jnp.zeros((output_size,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(x)

print(list(y.flatten().sum(1)))  # No numpy conversion needed in JAX
I extracted the code from that output and ran it; the extracted program is identical to the listing above. Running it produced:
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 32, in <module>
    model = MLP(10, 20, 30, key)
            ^^^^^^^^^^^^^^^^^^^^
  File "/usr/src/app/main.py", line 12, in __init__
    self.linear1 = Linear(input_size, hidden_size, key1)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/src/app/main.py", line 24, in __init__
    self.w = random.normal(w_key, (input_size, output_size)) * jax.lax.rsqrt(input_size)  # Xavier initialization
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/lax/lax.py", line 395, in rsqrt
    return rsqrt_p.bind(x)
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
TypeError: rsqrt does not accept dtype int32. Accepted dtypes are subtypes of floating, complexfloating.

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
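For reference, jax.lax.rsqrt only accepts floating or complex inputs, and input_size here is a Python int (traced as int32), hence the TypeError above. A minimal sketch of two ways to compute the same scale factor without the dtype error (illustrative only, not part of the benchmark run):

import jax.numpy as jnp
from jax import lax

input_size = 10
# lax.rsqrt(input_size) fails because the Python int becomes an int32 array.
scale = lax.rsqrt(jnp.float32(input_size))  # cast to a float dtype first
# Or stay in plain Python arithmetic, which avoids lax entirely:
scale_alt = input_size ** -0.5
print(float(scale), scale_alt)  # both are 1/sqrt(10)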
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        k1, k2 = random.split(key)
        self.linear1_w = random.normal(k1, (input_size, hidden_size))
        self.linear1_b = jnp.zeros(hidden_size)
        self.linear2_w = random.normal(k2, (hidden_size, output_size))
        self.linear2_b = jnp.zeros(output_size)

    def __call__(self, x):
        x = jnp.dot(x, self.linear1_w) + self.linear1_b
        x = jax.nn.relu(x)
        x = jnp.dot(x, self.linear2_w) + self.linear2_b
        return jax.nn.softmax(x, axis=1)

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(key, (5, 10))
y = model(x)
print(list(y.flatten().sum(1)))
I extracted the code from that output and ran it; the extracted program is identical to the listing above. Running it produced:
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 24, in <module>
    print(list(y.flatten().sum(1)))
               ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/array_methods.py", line 363, in _sum
    return reductions.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 289, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 213, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 111, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
import jax
import jax.numpy as jnp

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        k1, k2 = jax.random.split(key)
        self.linear1_w = jax.random.normal(k1, (hidden_size, input_size))
        self.linear1_b = jnp.zeros((hidden_size,))
        self.linear2_w = jax.random.normal(k2, (output_size, hidden_size))
        self.linear2_b = jnp.zeros((output_size,))

    def __call__(self, x):
        x = jnp.dot(x, self.linear1_w.T) + self.linear1_b
        x = jax.nn.relu(x)
        x = jnp.dot(x, self.linear2_w.T) + self.linear2_b
        return jax.nn.softmax(x, axis=1)

key = jax.random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = jax.random.normal(key, (5, 10))
y = model(x)
print(jnp.sum(y, axis=1))
I extracted the code from that output and ran it; the extracted program is identical to the listing above. Running it produced:
[0.99999994 0.99999994 1. 1. 1.0000001 ]
Testing if the previous output contains the string 0.99999: True
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: True
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False
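The fourth attempt above runs cleanly and each row of the softmax output sums to approximately 1. As an aside, pure-JAX code more commonly keeps the parameters in an explicit pytree and jit-compiles a functional forward pass rather than storing arrays on a class; a minimal sketch along those lines (layer sizes and names are illustrative, not part of the benchmark):

import jax
import jax.numpy as jnp

def init_params(key, sizes=(10, 20, 30)):
    # One (W, b) pair per layer, stored in a plain list (a pytree JAX can trace through).
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

@jax.jit
def forward(params, x):
    (w1, b1), (w2, b2) = params
    h = jax.nn.relu(x @ w1 + b1)
    return jax.nn.softmax(h @ w2 + b2, axis=1)

key = jax.random.PRNGKey(0)
params = init_params(key)
x = jax.random.normal(key, (5, 10))
print(jnp.sum(forward(params, x), axis=1))  # each row of the softmax sums to ~1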
Convert the following PyTorch code to pure JAX. Do not use any other libraries.
import torch

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def __call__(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.nn.functional.softmax(x, dim=1)

model = MLP(10, 20, 30)
x = torch.randn(5, 10)
y = model(x)
print(list(y.flatten().numpy().sum(1)))
import jax
import jax.numpy as jnp
from jax import random

class MLP:
    def __init__(self, input_size, hidden_size, output_size, key):
        k1, k2 = random.split(key)
        self.linear1 = Linear(input_size, hidden_size, k1)
        self.linear2 = Linear(hidden_size, output_size, k2)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return jax.nn.softmax(x, axis=1)

class Linear:
    def __init__(self, in_features, out_features, key):
        wkey, bkey = random.split(key)
        self.w = random.normal(wkey, (out_features, in_features))
        self.b = random.normal(bkey, (out_features,))

    def __call__(self, x):
        return jnp.dot(x, self.w.T) + self.b

key = random.PRNGKey(0)
model = MLP(10, 20, 30, key)
x = random.normal(random.PRNGKey(1), (5, 10))
y = model(x)
print(list(y.flatten().sum(1)))
I extracted the code from that output and ran it; the extracted program is identical to the listing above. Running it produced:
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 33, in <module>
    print(list(y.flatten().sum(1)))
               ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/array_methods.py", line 363, in _sum
    return reductions.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 289, in sum
    return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 213, in _reduce_sum
    return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 111, in _reduction
    pos_dims, dims = _reduction_dims(a, axis)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in _reduction_dims
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 160, in <genexpr>
    canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in _canonicalize_axis_allow_named
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/reductions.py", line 153, in <lambda>
    return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: axis 1 is out of bounds for array of dimension 1

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Testing if the previous output contains the string 0.99999: False
Testing if the previous output contains the string 1.0: False
Testing if the previous output contains the string 1.00000: False
Testing if the previous output contains the string 1.0,: False
Testing if the previous output contains the string import torch: False