JAX는 XLA를 이용해서 Numpy code를 accelerator에서 돌리는 것이다. 요즘 하두 JAX에 대한 말이 많아서, 나도 빠르게 익혀서 내 분야인 강화 학습, 딥러닝에 적용하고자 한다. 기본적으로 JAX document를 하나하나 읽어보고, 이후에 더 추가적으로 구현들을 봐보려 생각한다. 특히 Diffusion Model이나 RL 내에서 적어도 3배 정도 빨라진다 지인들이 추천해서, 안 할 수가 없다. 공부하면서 적은 블로그이니, 틀릴 수 있음을 감안해주면 좋을 거 같다.
생각이 나서, vscode에 colab도 연동해서 공부해보려 한다. (https://dacon.io/forum/406050에서 참고해보려 한다) 평소에는 연구실의 gpu를 썼는데 살짝 갭 타임이 있어서 (MIT에 입학 전에... 복잡하다. 사실 사용은 가능하나 vpn 어쩌고... 너무 귀찮다) colab pro를 잠깐 이용하기로 했다.
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
How to Think in JAX — JAX documentation
jax.readthedocs.io
jax.numpy 는 numpy랑 비슷한 숫자 다루는 툴이라 생각하면 편하다.
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp)
type(x_jnp)
output : jaxlib.xla_extension.DeviceArray
# JAX: immutable arrays
x = jnp.arange(10)
y = x.at[0].set(10)
print(x)
print(y)
output : [0 1 2 3 4 5 6 7 8 9] [10 1 2 3 4 5 6 7 8 9]
x[0] = 10과 같이 진행한다면 업데이트가 안된다.
JAX.numpy와 같이 jax.lax도 있는데, 이는 lower-level API로써 좀 더 엄격하지만 더 powerful 하다고 한다. 예를 들어서, jax.numpy는 1 + 1.0 이 가능한데, lax는 그렇지 않다. 그래서 복잡하게 convolution 같은 것도 translated 되어 있다.
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
'연구 > JAX' 카테고리의 다른 글
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
---|---|
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |
JAX 공부 3 - Pure function, array updates (0) | 2022.07.18 |
JAX 공부 2 - JIT 기초 (0) | 2022.07.18 |