I have a list of indexes and I want to convert it to a one-hot vector using JAX. The function should be jittable, and it should be jitted. Name the function one_hot; it should take two arguments: the first is the indexes and the second is the number of possible labels. Just give me the code.
```python
import jax
import jax.numpy as jnp

@jax.jit
def one_hot(indexes, num_labels):
    return jax.nn.one_hot(indexes, num_labels)

# Example usage:
indexes = jnp.array([0, 1, 2])
num_labels = 3
print(one_hot(indexes, num_labels))
```
I extracted the following code from that output:
```python
import jax
import jax.numpy as jnp

@jax.jit
def one_hot(indexes, num_labels):
    return jax.nn.one_hot(indexes, num_labels)

# Example usage:
indexes = jnp.array([0, 1, 2])
num_labels = 3
print(one_hot(indexes, num_labels))
```
Running the following program:
```python
import jax
import jax.numpy as jnp


@jax.jit
def one_hot(indexes, num_labels):
    return jax.nn.one_hot(indexes, num_labels)

# Example usage:
indexes = jnp.array([0, 1, 2])
num_labels = 3
print(one_hot(indexes, num_labels))


import jax.numpy as jnp
answer = str(one_hot(jnp.array([1,2,0]),5))
expected = str(jnp.array([[0.0,1.0,0.0,0.0,0.0],[0.0,0.0,1.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]))
assert answer == expected, f'Wrong answer; got {answer} instead of {expected}'
print('All tests passed')
```
```
Traceback (most recent call last):
  File "/usr/src/app/main.py", line 12, in <module>
    print(one_hot(indexes, num_labels))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/src/app/main.py", line 7, in one_hot
    return jax.nn.one_hot(indexes, num_labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/nn/functions.py", line 706, in one_hot
    num_classes = core.concrete_dim_or_error(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The error arose in jax.nn.one_hot argument `num_classes`.
The error occurred while tracing the function one_hot at /usr/src/app/main.py:5 for jit. This concrete value was not available in Python because it depends on the value of the argument num_labels.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
Substring Evaluation

Testing if the previous output contains the string `All tests passed`: False
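The traceback says `num_labels` was traced by `jax.jit`, but `jax.nn.one_hot` needs a concrete integer for `num_classes` because it fixes the output shape. A minimal sketch of one possible fix, assuming the number of labels is always passed as a plain Python `int`, is to mark that argument as static. The use of `functools.partial` with `static_argnames` is an assumption of this sketch and not part of the original answer. The test lines from the program above are reused unchanged.

```python
from functools import partial

import jax
import jax.numpy as jnp


# num_labels determines the output shape, so it must be known at trace time.
# Marking it static makes jit treat it as a compile-time constant; JAX will
# compile a separate version for each distinct value of num_labels.
@partial(jax.jit, static_argnames="num_labels")
def one_hot(indexes, num_labels):
    return jax.nn.one_hot(indexes, num_labels)


# Same check as in the test program above.
answer = str(one_hot(jnp.array([1, 2, 0]), 5))
expected = str(jnp.array([[0.0, 1.0, 0.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0, 0.0, 0.0]]))
assert answer == expected, f"Wrong answer; got {answer} instead of {expected}"
print("All tests passed")
```

Because `static_argnames` is resolved against the function signature, the positional call `one_hot(x, 5)` is also treated as static. Passing a JAX array as `num_labels` would still fail, since static arguments must be hashable Python values.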