연구/JAX 14

JAX 공부 4 - out of bounds indexing, non array input, random numbers

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html 🔪 JAX - The Sharp Bits 🔪 — JAX documentation jax.readthedocs.io Out of bound를 만나면 그냥 실행을 안 해버린다. (넘겨서 실행한다). Non array input은 다음과 같다. try: jnp.sum([1, 2, 3]) except TypeError as e: print(f"TypeError: {e}") TypeError: sum requires ndarray or scalar arguments, got at position 0. 즉, 기존의 numpy와 다르게 list를 안 받는 걸 체크할 수 있다. 사실 이 ..

연구/JAX 2022.07.18

JAX 공부 3 - Pure function, array updates

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html 🔪 JAX - The Sharp Bits 🔪 — JAX documentation jax.readthedocs.io Pure function 개념에 대해서 좀 헷갈려서 적으면서 정리하고자 한다. def impure_print_side_effect(x): print("Executing function") # This is a side-effect return x # The side-effects appear during the first run print ("First call: ", jit(impure_print_side_effect)(4.)) # Subsequent runs..

연구/JAX 2022.07.18

JAX 공부 2 - JIT 기초

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html를 읽으며 공부하고 있다. How to Think in JAX — JAX documentation jax.readthedocs.io JIT이란 just in time이라는 뜻이다. 이 decorator를 사용한다면, operation이 최적화된 다음에 한 번에 돌아간다. (데코레이터 : https://dojang.io/mod/page/view.php?id=2427 참고) 모든 JAX 코드가 JIT complied 될 수 있는 건 아니다. (static, known at complie time이어야 한다). 아.. 이래서 training에서는 시간이 비슷하고 infernece에서 속도가 ..

연구/JAX 2022.07.18

JAX 공부 - jax.numpy, jax.lax 기초

JAX는 XLA를 이용해서 Numpy code를 accelerator에서 돌리는 것이다. 요즘 하두 JAX에 대한 말이 많아서, 나도 빠르게 익혀서 내 분야인 강화 학습, 딥러닝에 적용하고자 한다. 기본적으로 JAX document를 하나하나 읽어보고, 이후에 더 추가적으로 구현들을 봐보려 생각한다. 특히 Diffusion Model이나 RL 내에서 적어도 3배 정도 빨라진다 지인들이 추천해서, 안 할 수가 없다. 공부하면서 적은 블로그이니, 틀릴 수 있음을 감안해주면 좋을 거 같다. 생각이 나서, vscode에 colab도 연동해서 공부해보려 한다. (https://dacon.io/forum/406050에서 참고해보려 한다) 평소에는 연구실의 gpu를 썼는데 살짝 갭 타임이 있어서 (MIT에 입학 전에..

연구/JAX 2022.07.18