연구/JAX

JAX 공부 - jax.numpy, jax.lax 기초

Chanwoo Park 2022. 7. 18. 13:09
728x90

JAX는 XLA를 이용해서 Numpy code를 accelerator에서 돌리는 것이다. 요즘 하두 JAX에 대한 말이 많아서, 나도 빠르게 익혀서 내 분야인 강화 학습, 딥러닝에 적용하고자 한다. 기본적으로 JAX document를 하나하나 읽어보고, 이후에 더 추가적으로 구현들을 봐보려 생각한다. 특히 Diffusion Model이나 RL 내에서 적어도 3배 정도 빨라진다 지인들이 추천해서, 안 할 수가 없다. 공부하면서 적은 블로그이니, 틀릴 수 있음을 감안해주면 좋을 거 같다. 

 

생각이 나서, vscode에 colab도 연동해서 공부해보려 한다. (https://dacon.io/forum/406050에서 참고해보려 한다) 평소에는 연구실의 gpu를 썼는데 살짝 갭 타임이 있어서 (MIT에 입학 전에... 복잡하다. 사실 사용은 가능하나 vpn 어쩌고... 너무 귀찮다) colab pro를 잠깐 이용하기로 했다. 

 

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

 

How to Think in JAX — JAX documentation

 

jax.readthedocs.io

jax.numpy 는 numpy랑 비슷한 숫자 다루는 툴이라 생각하면 편하다. 

 

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp)

type(x_jnp)

output : jaxlib.xla_extension.DeviceArray
 

 

np.linspace랑 사실상 같아 보인다. 하지만 Numpy랑 가장 중요하게 다른 것은, immutable 하다는 것이다. JAX에서 하나의 element를 업데이트 하기 위해서는 index update syntax가 필요하다. 
# JAX: immutable arrays
x = jnp.arange(10)
y = x.at[0].set(10)
print(x)
print(y)

output : [0 1 2 3 4 5 6 7 8 9] [10 1 2 3 4 5 6 7 8 9]

x[0] = 10과 같이 진행한다면 업데이트가 안된다. 

 

JAX.numpy와 같이 jax.lax도 있는데, 이는 lower-level API로써 좀 더 엄격하지만 더 powerful 하다고 한다. 예를 들어서, jax.numpy는 1 + 1.0 이 가능한데, lax는 그렇지 않다. 그래서 복잡하게 convolution 같은 것도 translated 되어 있다. 

result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]