https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html
Automatic Vectorization in JAX — JAX documentation
jax.readthedocs.io
이 문서를 공부하면서 작성한다.
jax.vmap을 이용해서 vectorized implementation을 진행한다. 이는 jax.jit과 비슷하게 automatically adding batch axes 한다. 그냥 잘 쓰면 되는 듯... jit이랑 vmap은 composable 하다고 한다. 위의 문서는 내용이 없는 듯..
https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html
Advanced Automatic Differentiation in JAX — JAX documentation
jax.readthedocs.io
Higher grad를 구할 때는 그냥 jax.grad(jax.grad())... 를 쓰면 된다. (1 변수일 때) 그런데 n변수면 텐서 형태로 n번 미분이 나오기 때문에 어려워지는데, 일단 Jacobian을 계산할 수 있게 f: R^n -> R의 두 번 미분을 할 수 있는 툴을 제공해준다.
jax.jacfwd랑 jax.jacrev가 forward 방향, reverse방향 미분을 뜻하게 된다. 사실 Autodiff Cookbook(https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)을 더 읽어봐야 알 것 같고, 이는 뒷부분 튜토리얼에 있어서 일단 생략하도록 한다.
Higher order differentiation에 대한 optimization이 필요한 분야 중 하나는 Model-Agnostic Meta-Learning(MAML)이다. 개인적으로 이 이론 논문들을 읽고 감동을 받았는데, 코딩이 굉장히 jax에서는 쉬워 보인다.
def meta_loss_fn(params, data):
"""Computes the loss after one step of SGD."""
grads = jax.grad(loss_fn)(params, data)
return loss_fn(params - lr * grads, data)
meta_grads = jax.grad(meta_loss_fn)(params, data)
한편, 우리가 gradient back propagating을 안 하고 시을 때도 있다. 예를 들어서 TD(0) 알고리즘을 생각해보자. (참고 : https://hakucode.tistory.com/8) 이때 현재 state에 대한 theta만 update 해야 하므로, stop을 잘 걸어줘야 한다.
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return (jax.lax.stop_gradient(target) - v_tm1) ** 2
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
이게 뭔 소린가 할 수 있는데, 강화 학습을 좀 잘 알아야 이해할 수 있다. 하고 싶은 말은, stop_gradient를 통해서 일정 부분은 상수로 만들어 줄 수 있다는 점이다.
다른 예시로 Straight-though gradient와 per-example gradient가 있다. jit, vmap, grad를 잘 조합하면 쓸 수 있다.
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
vmap을 좀 더 잘 알아야겠다. in_axes 가 정확히 어떤 의미인지 잘 모르겠다. 예상하기로는 theta 만 제외하고 다 해주는 거 같은데...
'연구 > JAX' 카테고리의 다른 글
JAX 공부 10 - Parallel Evaluation in JAX (0) | 2022.07.20 |
---|---|
JAX 공부 9 - Pytrees (0) | 2022.07.19 |
JAX 공부 8 - JIT in JAX (0) | 2022.07.19 |
JAX 공부 7 - difference from numpy (1) | 2022.07.19 |
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |