https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html를 읽으며 공부하고 있다.
How to Think in JAX — JAX documentation
jax.readthedocs.io
JIT이란 just in time이라는 뜻이다. 이 decorator를 사용한다면, operation이 최적화된 다음에 한 번에 돌아간다. (데코레이터 : https://dojang.io/mod/page/view.php?id=2427 참고) 모든 JAX 코드가 JIT complied 될 수 있는 건 아니다. (static, known at complie time이어야 한다). 아.. 이래서 training에서는 시간이 비슷하고 infernece에서 속도가 빨라진다고 하는 거구나.
import jax.numpy as jnp
from jax import jit
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
norm_compiled = jit(norm)
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
결과는 다음과 같다.
1000 loops, best of 5: 326 µs per loop
The slowest run took 8.28 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 124 µs per loop
일단 여기서 나도 새로운 사실인 np.allclose라는 함수를 알게 되었다. 두 개의 numpy를 비교하는 것 같다.
block_until_ready()는 JAX의 asychronous dispatch를 위해서.. 이용한다고 했다. 최대한 같은 조건으로 가려는 것 같다.
결과를 보면, 후자가 훨씬 빠름을 볼 수 있다.
하지만 JAX를 통해서 JIT과 imcompatible하지 않은 케이스들도 있다.
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
jit(get_negatives)(x)
이는 실행되지 않는다.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-19-ec8799cf80d7> in <module>()
----> 1 jit(get_negatives)(x)
IndexError: Array boolean indices must be concrete.
@jit
def f(x, neg):
return -x if neg else x
f(1, True)
이걸 실행하면 오류가 나온다. traced value에 대해서 함수가 depend 할 수 없기 때문이다. 이를 해결해주기 위해서 다음을 시행한다.
from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
return -x if neg else x
f(1, True)
값이 static or traced가 될 수 있듯, operation도 static or traced다. Static은 compile time에 evaluated 되지만, trace는 run-time in XLA에서 evaluated 된다. 그러므로, numpy는 static을 원할 때, jax는 traced 되기를 원할 때 이용하면 된다.
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x)
이것도 실행이 안 되는데, 이유는 복잡하다. x는 traced value이고, x.shape는 static인데, jnp.array에 넣으면 다시 traced이 되고, jnp.prod도 traced function이 되었는데, 이를 reshape 하면 불가능하기 때문이다. 이를 수정하면 다음과 같은 코드가 가능하다.
from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
return x.reshape((np.prod(x.shape),))
f(x)
그래서, numpy와 jax.numpy를 잘 섞어, static인 것을 numpy로, run-time에서 optimized 되는 jax로 나눠서 잘해야 한다.
'연구 > JAX' 카테고리의 다른 글
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
---|---|
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |
JAX 공부 3 - Pure function, array updates (0) | 2022.07.18 |
JAX 공부 - jax.numpy, jax.lax 기초 (1) | 2022.07.18 |