연구/JAX

JAX 공부 4 - out of bounds indexing, non array input, random numbers

Chanwoo Park 2022. 7. 18. 14:47
728x90

 

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

 

🔪 JAX - The Sharp Bits 🔪 — JAX documentation

 

jax.readthedocs.io

 

Out of bound를 만나면 그냥 실행을 안 해버린다. (넘겨서 실행한다). 

 

 

Non array input은 다음과 같다. 

try:
  jnp.sum([1, 2, 3])
except TypeError as e:
  print(f"TypeError: {e}")
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

즉,  기존의 numpy와 다르게 list를 안 받는 걸 체크할 수 있다. 사실 이 이유는, jax의 디자인과 관련이 있다. trace에 대한 이야기인데, 어쨌든 길게 설명하기보다, 좀 더 엄격하게 받는다고 보면 좋을 것 같다. 위의 코드를 다음과 같이 수정한다. 

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)

여기에도 좀 이슈가 있는 것 같다. 이후에 살펴봐야겠다. 

 

Random Numbers 

from jax import random
key = random.PRNGKey(0)
key
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)

결과는 다음과 같다. 

[-0.20584226]
[0 0]
[-0.20584226]
[0 0]

key가 예상치 못하게 다음과 같이 같은 결과를 내준다. 

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)

이런 식으로 key를 나누면서 새로운 key를 propagate 한다. 

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))

 

'연구 > JAX' 카테고리의 다른 글

JAX 공부 6 - jax as accelerated numpy, grad  (0) 2022.07.19
JAX 공부 5 - Control Flow  (0) 2022.07.18
JAX 공부 3 - Pure function, array updates  (0) 2022.07.18
JAX 공부 2 - JIT 기초  (0) 2022.07.18
JAX 공부 - jax.numpy, jax.lax 기초  (1) 2022.07.18