연구/JAX

JAX 공부 7 - difference from numpy

Chanwoo Park 2022. 7. 19. 16:24
728x90

https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html

 

JAX As Accelerated NumPy — JAX documentation

 

jax.readthedocs.io

을 공부하면서 작성하고 있다. 

 

JAX는 함수형 프로그래밍을 하기 위해서 만들어졌다는 게 중요하다. 아... 근데 내가 직접 이 문서를 볼수록 나랑 잘 맞는 거 같다. 나는 학창 시절에도 python보다 c++ 이 더 편했다. python의 ambiguity가 나를 더 힘들게 했는데, 점점 재미가 붙는 거 같다. 아무튼.. 계속해서 설명하면 

JAX를 익히는데 함수형 프로그래밍을 잘 알 필요는 없다. 사실 나도 함수형 프로그래밍을 모른다. 다만 functional programming이라는 개념이 JAX를 이용한 코딩에 도움이 되는 것도 분명하다. 

 

Side-effect-free 한 코딩을 해야하는데, 예를 들어서 프린트를 넣으면 안 된다. 흠.. 나는 폭풍 프린트를 이용한 디버깅을 좋아하는데... 그래서 좀 찾아봤다. https://github.com/google/jax/issues/4615를 참고해보면 좋을 거 같다. (미래에 내가 쓸 거 같다) 

 

이런 것 중 가장 재밌는 것은 in-place modification인데, 사실 고치는 것은 안되고 이는 새로운 array를 modicifation이 만들어지게 하고 만드는 것이다. (대체가 아님) 원래의 것은 untouch 되어 no side effect가 된다. 예시 코드를 보자. 

 

def jax_in_place_modify(x):
  return x.at[0].set(123)

y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
DeviceArray([123,   2,   3], dtype=int32)
y
DeviceArray([1, 2, 3], dtype=int32)

그럼 실제로 새로운 array를 만드니까 매우 비효율 적이 아닌가? JAX는 JIT을 이용하기 전에 먼저 complied 되기도 한다. 장점들을 잘 섞어서 사용해야 할 듯.