연구/JAX

JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX

Chanwoo Park 2022. 7. 19. 18:13
728x90

https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html

 

Automatic Vectorization in JAX — JAX documentation

 

jax.readthedocs.io

이 문서를 공부하면서 작성한다. 

jax.vmap을 이용해서 vectorized implementation을 진행한다. 이는 jax.jit과 비슷하게 automatically adding batch axes 한다. 그냥 잘 쓰면 되는 듯... jit이랑 vmap은 composable 하다고 한다. 위의 문서는 내용이 없는 듯..

 

https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

 

Advanced Automatic Differentiation in JAX — JAX documentation

 

jax.readthedocs.io

Higher grad를 구할 때는 그냥 jax.grad(jax.grad())... 를 쓰면 된다. (1 변수일 때) 그런데 n변수면 텐서 형태로 n번 미분이 나오기 때문에 어려워지는데, 일단 Jacobian을 계산할 수 있게 f: R^n -> R의 두 번 미분을 할 수 있는 툴을 제공해준다. 

jax.jacfwd랑 jax.jacrev가 forward 방향, reverse방향 미분을 뜻하게 된다. 사실 Autodiff Cookbook(https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)을 더 읽어봐야 알 것 같고, 이는 뒷부분 튜토리얼에 있어서 일단 생략하도록 한다. 

 

Higher order differentiation에 대한 optimization이 필요한 분야 중 하나는 Model-Agnostic Meta-Learning(MAML)이다. 개인적으로 이 이론 논문들을 읽고 감동을 받았는데, 코딩이 굉장히 jax에서는 쉬워 보인다. 

def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)

 

한편, 우리가 gradient back propagating을 안 하고 시을 때도 있다. 예를 들어서 TD(0) 알고리즘을 생각해보자. (참고 : https://hakucode.tistory.com/8) 이때 현재 state에 대한 theta만 update 해야 하므로, stop을 잘 걸어줘야 한다. 

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return (jax.lax.stop_gradient(target) - v_tm1) ** 2

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta

이게 뭔 소린가 할 수 있는데, 강화 학습을 좀 잘 알아야 이해할 수 있다. 하고 싶은 말은, stop_gradient를 통해서 일정 부분은 상수로 만들어 줄 수 있다는 점이다. 

 

다른 예시로 Straight-though gradient와 per-example gradient가 있다. jit, vmap, grad를 잘 조합하면 쓸 수 있다. 

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

 vmap을 좀 더 잘 알아야겠다. in_axes 가 정확히 어떤 의미인지 잘 모르겠다. 예상하기로는 theta 만 제외하고 다 해주는 거 같은데... 

'연구 > JAX' 카테고리의 다른 글

JAX 공부 10 - Parallel Evaluation in JAX  (0) 2022.07.20
JAX 공부 9 - Pytrees  (0) 2022.07.19
JAX 공부 8 - JIT in JAX  (0) 2022.07.19
JAX 공부 7 - difference from numpy  (1) 2022.07.19
JAX 공부 6 - jax as accelerated numpy, grad  (0) 2022.07.19