전체 글 45

JAX 공부 - 쉬어가기

Getting Started 관련 문서를 전부 읽었다. 이제 Reference Documentation, Advanced JAX Tutorial들을 하나씩 자세히 볼 예정이다. 텐서 플로우 / 파이 토치 / JAX를 적당히 섞어가면서 코딩하는 것이 좋아 보인다. JAX를 공부하면서 느낀 것은 JAX 자체를 공부할 때 이 원리를 정확하게 파악해야만 한다는 것이다. JIT이라는 개념을 이해하는 것도 매우 어렵고, 이를 코딩에 적용하는 것도 어렵다. 따라서 나같이 함수형 프로그래밍이 처음인 사람들은 이게 왜 되는지 항상 고민해보고 내가 생각한 게 과연 함수형 프로그래밍 향 생각인지를 고민해야 하는 것 같다.

연구/JAX 2022.07.20

JAX 공부 12 - pjit

https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html Introduction to pjit — JAX documentation jax.readthedocs.io 여러 device에서 XLA를 돌리기 위해서 있는 툴이다. 같은 프로그램을 N개의 device에서 돌리기 위해서 있고, XLA SPMD partitioner에 대해서 사용한다고 한다. Mesh로 나누게 된다. Mesh 내에서 in_axis resources, out_axis_resources를 잘 대응시켜야 한다. mesh_shape = (4, 2) devices = np.asarray(jax.devices()).reshape(*mesh_shape) # 'x', 'y' axis names are ..

연구/JAX 2022.07.20

JAX 공부 10 - Parallel Evaluation in JAX

https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#specifying-in-axes Parallel Evaluation in JAX — JAX documentation jax.readthedocs.io 이 문서를 공부하며 작성했다 사실 기존의 gpu와 다를 건 거의 없다. 이 문서는 tpu의 적용을 이야기하는데, 기존의 jax.vmap을 jax.pmap으로 바꾼다 생각하면 된다. 차이가 나는 건, parallelized function은 SharedDeviceArray를 리턴하는데, 왜냐면 이 행렬이 parallelizm에서 전부에 공유되었기 때문이다. in_axes는 내가 헷갈리는 개념이라 다른 곳에서 레퍼를 더 가져온다. https:/..

연구/JAX 2022.07.20

JAX 공부 9 - Pytrees

https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html Working with Pytrees — JAX documentation jax.readthedocs.io 이 문서를 공부하며 작성하고 있다. pytree는 한마디로 모든 걸 짬뽕해둔 거라 할 수 있다. 여기부터 mlp를 training 하는 게 나오므로 코드를 잘 보도록 하자. 처음 init_mlp_parameters 는 params에 dict를 끼워 넣는 것이다. weight와 bias를 넣게 된다. import numpy as np def init_mlp_params(layer_widths): params = [] for n_in, n_out in zip(layer_widths[:-1],..

연구/JAX 2022.07.19

JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX

https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html Automatic Vectorization in JAX — JAX documentation jax.readthedocs.io 이 문서를 공부하면서 작성한다. jax.vmap을 이용해서 vectorized implementation을 진행한다. 이는 jax.jit과 비슷하게 automatically adding batch axes 한다. 그냥 잘 쓰면 되는 듯... jit이랑 vmap은 composable 하다고 한다. 위의 문서는 내용이 없는 듯.. https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html Advanc..

연구/JAX 2022.07.19

JAX 공부 8 - JIT in JAX

JIT을 공부하기 전에, 내가 위키에서 JIT의 정의를 가져왔다. 사실 JIT이 뭔지 정확히 감이 안 잡혔는데, 위키를 보니 감을 잡아가는 것 같다. JIT 컴파일(just-in-time compilation) 또는 동적 번역(dynamic translation)은 프로그램을 실제 실행하는 시점에 기계어로 번역하는 컴파일 기법이다. 전통적인 입장에서 컴퓨터 프로그램을 만드는 방법은 두 가지가 있는데, 인터프리트 방식과 정적 컴파일 방식으로 나눌 수 있다. 이 중 인터프리트 방식은 실행 중 프로그래밍 언어를 읽어가면서 해당 기능에 대응하는 기계어 코드를 실행하며, 반면 정적 컴파일은 실행하기 전에 프로그램 코드를 기계어로 번역한다. JIT 컴파일러는 두 가지의 방식을 혼합한 방식으로 생각할 수 있는데, 실..

연구/JAX 2022.07.19

JAX 공부 7 - difference from numpy

https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html JAX As Accelerated NumPy — JAX documentation jax.readthedocs.io 을 공부하면서 작성하고 있다. JAX는 함수형 프로그래밍을 하기 위해서 만들어졌다는 게 중요하다. 아... 근데 내가 직접 이 문서를 볼수록 나랑 잘 맞는 거 같다. 나는 학창 시절에도 python보다 c++ 이 더 편했다. python의 ambiguity가 나를 더 힘들게 했는데, 점점 재미가 붙는 거 같다. 아무튼.. 계속해서 설명하면 JAX를 익히는데 함수형 프로그래밍을 잘 알 필요는 없다. 사실 나도 함수형 프로그래밍을 모른다. 다만 functional programmin..

연구/JAX 2022.07.19

JAX 공부 6 - jax as accelerated numpy, grad

https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html JAX As Accelerated NumPy — JAX documentation jax.readthedocs.io 공부하면서 정리하고 있다. DeviceArray : JAX가 array를 표현하는 방법이다. JAX는 다른 backend (CPU, GPU, TPU)에서 같은 코드로 돌릴 수 있다. asynchronous dispatch로 인해서 (https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch 참고) 뒤에 추가적인 코드를 붙인다. 쉽게 asynchronous dispatch를 이야기하면 계산을 기다리지..

연구/JAX 2022.07.19

Application: SOP

(주 1) SOP는 공개하지는 않았지만, 제 메일로 자신이 '딥러닝 이론'을 전공하고 싶은 경우, 자신의 CV랑 짤막한 소개를 보내주시면 제 SOP를 보내드리겠습니다. 제 메일은 cpark97@mit.edu입니다. (주 2) 과마다 매우 다를 것이기 때문에 참고만 하시면 좋을 것 같습니다. SOP는 내가 생각했을 때 자신이 지금까지 쌓아온 실적과 비슷하게 중요한 요소이다. 그 이유는 첫째, 생각보다 지원자들 중에서 현재의 Literature를 파악한 사람이 잘 없다. 사실 학석사 정도에서 현재 리터레쳐를 잘 파악하는 것은 거의 불가능이다. 내 sop를 지금 와서 돌아보면 부끄럽다. 당연한 것이기도 하지만, 그만큼 조사를 더 해서 어떤 일이 현재 일어나고 있는지를 정확하게 파악하고 그를 넘어서 앞으로 사람..