728x90
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
🔪 JAX - The Sharp Bits 🔪 — JAX documentation
jax.readthedocs.io
Out of bound를 만나면 그냥 실행을 안 해버린다. (넘겨서 실행한다).
Non array input은 다음과 같다.
try:
jnp.sum([1, 2, 3])
except TypeError as e:
print(f"TypeError: {e}")
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
즉, 기존의 numpy와 다르게 list를 안 받는 걸 체크할 수 있다. 사실 이 이유는, jax의 디자인과 관련이 있다. trace에 대한 이야기인데, 어쨌든 길게 설명하기보다, 좀 더 엄격하게 받는다고 보면 좋을 것 같다. 위의 코드를 다음과 같이 수정한다.
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
여기에도 좀 이슈가 있는 것 같다. 이후에 살펴봐야겠다.
Random Numbers
from jax import random
key = random.PRNGKey(0)
key
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
결과는 다음과 같다.
[-0.20584226]
[0 0]
[-0.20584226]
[0 0]
key가 예상치 못하게 다음과 같이 같은 결과를 내준다.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
이런 식으로 key를 나누면서 새로운 key를 propagate 한다.
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
'연구 > JAX' 카테고리의 다른 글
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
---|---|
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 3 - Pure function, array updates (0) | 2022.07.18 |
JAX 공부 2 - JIT 기초 (0) | 2022.07.18 |
JAX 공부 - jax.numpy, jax.lax 기초 (1) | 2022.07.18 |