JAX는 XLA를 이용해서 Numpy code를 accelerator에서 돌리는 것이다. 요즘 하두 JAX에 대한 말이 많아서, 나도 빠르게 익혀서 내 분야인 강화 학습, 딥러닝에 적용하고자 한다. 기본적으로 JAX document를 하나하나 읽어보고, 이후에 더 추가적으로 구현들을 봐보려 생각한다. 특히 Diffusion Model이나 RL 내에서 적어도 3배 정도 빨라진다 지인들이 추천해서, 안 할 수가 없다. 공부하면서 적은 블로그이니, 틀릴 수 있음을 감안해주면 좋을 거 같다. 생각이 나서, vscode에 colab도 연동해서 공부해보려 한다. (https://dacon.io/forum/406050에서 참고해보려 한다) 평소에는 연구실의 gpu를 썼는데 살짝 갭 타임이 있어서 (MIT에 입학 전에..