https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html
JAX As Accelerated NumPy — JAX documentation
jax.readthedocs.io
을 공부하면서 작성하고 있다.
JAX는 함수형 프로그래밍을 하기 위해서 만들어졌다는 게 중요하다. 아... 근데 내가 직접 이 문서를 볼수록 나랑 잘 맞는 거 같다. 나는 학창 시절에도 python보다 c++ 이 더 편했다. python의 ambiguity가 나를 더 힘들게 했는데, 점점 재미가 붙는 거 같다. 아무튼.. 계속해서 설명하면
JAX를 익히는데 함수형 프로그래밍을 잘 알 필요는 없다. 사실 나도 함수형 프로그래밍을 모른다. 다만 functional programming이라는 개념이 JAX를 이용한 코딩에 도움이 되는 것도 분명하다.
Side-effect-free 한 코딩을 해야하는데, 예를 들어서 프린트를 넣으면 안 된다. 흠.. 나는 폭풍 프린트를 이용한 디버깅을 좋아하는데... 그래서 좀 찾아봤다. https://github.com/google/jax/issues/4615를 참고해보면 좋을 거 같다. (미래에 내가 쓸 거 같다)
이런 것 중 가장 재밌는 것은 in-place modification인데, 사실 고치는 것은 안되고 이는 새로운 array를 modicifation이 만들어지게 하고 만드는 것이다. (대체가 아님) 원래의 것은 untouch 되어 no side effect가 된다. 예시 코드를 보자.
def jax_in_place_modify(x):
return x.at[0].set(123)
y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
DeviceArray([123, 2, 3], dtype=int32)
y
DeviceArray([1, 2, 3], dtype=int32)
그럼 실제로 새로운 array를 만드니까 매우 비효율 적이 아닌가? JAX는 JIT을 이용하기 전에 먼저 complied 되기도 한다. 장점들을 잘 섞어서 사용해야 할 듯.
'연구 > JAX' 카테고리의 다른 글
JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX (0) | 2022.07.19 |
---|---|
JAX 공부 8 - JIT in JAX (0) | 2022.07.19 |
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |