연구/JAX

JAX 공부 8 - JIT in JAX

Chanwoo Park 2022. 7. 19. 17:13
728x90

JIT을 공부하기 전에, 내가 위키에서 JIT의 정의를 가져왔다. 사실 JIT이 뭔지 정확히 감이 안 잡혔는데, 위키를 보니 감을 잡아가는 것 같다. 

JIT 컴파일(just-in-time compilation) 또는 동적 번역(dynamic translation)은 프로그램을 실제 실행하는 시점에 기계어로 번역하는 컴파일 기법이다. 전통적인 입장에서 컴퓨터 프로그램을 만드는 방법은 두 가지가 있는데, 인터프리트 방식과 정적 컴파일 방식으로 나눌 수 있다. 이 중 인터프리트 방식은 실행 중 프로그래밍 언어를 읽어가면서 해당 기능에 대응하는 기계어 코드를 실행하며, 반면 정적 컴파일은 실행하기 전에 프로그램 코드를 기계어로 번역한다. JIT 컴파일러는 두 가지의 방식을 혼합한 방식으로 생각할 수 있는데, 실행 시점에서 인터프리트 방식으로 기계어 코드를 생성하면서 그 코드를 캐싱하여, 같은 함수가 여러 번 불릴 때 매번 기계어 코드를 생성하는 것을 방지한다.

일단 JAX는 처음에 함수를 추적 특화해서, 잘 해석하는 중간 형태로 만들고, 변환 특이 interpretation으로 된다 볼 수 있다. 어려운 heavy 한 것을 simple statically typed expression language로 바꾸고, 그것이 jaxpr이다. 이 언어를 내가 해석할 수 있을 필요는 없어 보인다. 예전에 컴퓨터 구조 책 볼 때 비슷한 구조의 언어들이 있었는데 (ex add a b ) 그런 거랑 비슷한 거 같다. 

https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

 

Just In Time Compilation with JAX — JAX documentation

 

jax.readthedocs.io

이 문서를 공부하면서 작성한 글이다. 

 

Tracing을 할 때 JAX는 tracer object로 각 argument를 감싸게 된다. 그 후에 JAX operation을 function call 될 때 전부 기록하게 된다. 그 후 JAX는 이 record를 이용하게 된다. 앞에 

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
이렇게 하면 느린데, selu_jit = jax.jit(selu)를 만들고 돌리면 빠르다. 왜냐하면, selu_jit을 complied version of selu로 만들어두고, selu_jit을 x에 대해서 한 번 돌리고, XLA를 통해서 한 번 컴파일링이 되고, 이후 python implementation을 skip 한다. 
 
하지만 모든 것을 JIT 할 수는 없다. 예를 들어서 condition on the value 같은 것은 불가능하다. 
 
jax.lax.cond를 쓸 수도 있는데, 불가능할 때도 있다. 
 
 
# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)

이런 식으로 간접적으로 피할 수 있다. 혹은 jit이 들어갈 자리를 functools.partial을 이용해서도 가능하다. 

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))

JIT이 꼭 좋은 건 아닌데, machine learning에서는 복잡한 함수를 계속 반복해서 사용하므로 좋다. 

 

 

Caching behavior of jax.jit을 알면 좋은데, f가 처음 실행되면 그것이 저장되고, 이후 call은 f를 reuse 한다. 

jax.jit을 inside loop에서 시행하는 것을 피해야 한다. 동등한 함수가 재정의될 때 문제가 생길 수 있기 때문이다. (함수의 해시에 캐시가 의존하기 때문)