https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html
JAX As Accelerated NumPy — JAX documentation
jax.readthedocs.io
공부하면서 정리하고 있다.
DeviceArray : JAX가 array를 표현하는 방법이다. JAX는 다른 backend (CPU, GPU, TPU)에서 같은 코드로 돌릴 수 있다.
asynchronous dispatch로 인해서 (https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch 참고) 뒤에 추가적인 코드를 붙인다. 쉽게 asynchronous dispatch를 이야기하면 계산을 기다리지 않고 완료 대기 상태로 만든다는 것이다. 그래서 dispatch에 걸리는 시간만 계산할 수 있어서 시간 계산할 때는 좀 다르게 계산해주고, 그 방법이 block_until_ready() 함수이다. 실제로 중요한지는 모르겠으나, 재밌는 개념이라 공부해봤다.
개인적으로 직접 시간을 돌려봤는데, GPU 사용 시 가장 긴 시간이 짧은 시간보다 509배나 빠르다.
jax에서 jax.grad를 이용한다면 함수를 미분할 수 있다.
def sum_of_squares(x):
return jnp.sum(x**2)
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares_dx(x))
[2. 4. 6. 8.]
굉장히 pytorch랑 다른 부분이 잇는데, autodiff library는 (tensor flow나 pytorch에서 이용함). 이는 loss tensor를 이용하게 된다. (loss.backward()). JAX는 function을 이용해서 직접 미분을 계산하고, 기반되는 수학에 가장 가깝게 있으려고 한다. 이런 특성들은 variable들로 미분하는 것이 굉장히 직관적으로 보이게 한다. (first argument로 미분을 하게 만든다) 예를 들어서,
def sum_squared_error(x, y):
return jnp.sum((x-y)**2)
sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_squared_error_dx(x, y))
[-0.20000005 -0.19999981 -0.19999981 -0.19999981]
여기서 볼 수 있듯, 2(x-y)의 값이 나오고, 즉 첫 번째 argumenet로 나옴을 알 수 있다. 만약 다른 argument들로 되는 것을 보고 싶다면,
jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y
을 이용하면 된다.
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))
근데 머신러닝에는 argument가 많으니 엄청난 array들의 tuple이 나오나 하면 그것은 아니다. JAX는 pytrees라는 데이터 구조를 이용한다. (앞으로의 document에 있다고 하고 길게 봐야 할 것 같아서 당장 보지 않았다.)
def loss_fn(params, data):
...
grads = jax.grad(loss_fn)(params, data_batch)
이런 식으로 일반적으로 사용된다고 한다. 앞으로 쓰이는 과정이 어떻게 되는지 관찰해야겠다.
혹시 value와 grad가 동시에 필요하다고 해도, jax.value_and_grad()을 이용하면 되니 걱정하지 말도록 하자.
jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs))
우리가 코딩을 하다 보면 Auxiliary data들이 필요하게 되는데, 이를 위해서 중간 과정들을 output으로 내주기도 한다. 하지만 이는 잘 안되는데, 그 이유는 jax.grad가 벡터 함수를 미분할 수는 없기 때문이다.
그래서 has_aux라는 것을 이용하게 된다. 다음의 예제 코드로 익히면 편할 것 같다.
def squared_error_with_aux(x, y):
return sum_squared_error(x, y), x-y
jax.grad(squared_error_with_aux)(x, y)
위의 코드는 실행이 안되는데,
jax.grad(squared_error_with_aux, has_aux=True)(x, y)
'연구 > JAX' 카테고리의 다른 글
JAX 공부 8 - JIT in JAX (0) | 2022.07.19 |
---|---|
JAX 공부 7 - difference from numpy (1) | 2022.07.19 |
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |
JAX 공부 3 - Pure function, array updates (0) | 2022.07.18 |