728x90
https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
Introduction to pjit — JAX documentation
jax.readthedocs.io
여러 device에서 XLA를 돌리기 위해서 있는 툴이다.
같은 프로그램을 N개의 device에서 돌리기 위해서 있고, XLA SPMD partitioner에 대해서 사용한다고 한다.
Mesh로 나누게 된다. Mesh 내에서 in_axis resources, out_axis_resources를 잘 대응시켜야 한다.
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(devices, ('x', 'y'))
mesh
쓰여있는 코드들을 그대로 사용하면 될 것 같다. 원리만 알고 넘어가도 충분할 것 같아서 내가 실습해보지는 않았다. 실제로 돌리는 코드를 알아야 할 것 같은데, 일단 원리는 결국 내가 직접 assign을 해주는 것이라 보면 된다. 실제 diffusion 같은 것에서는 multicore GPU를 사용하니 이 example을 보면서 공부해보아야겠다.
'연구 > JAX' 카테고리의 다른 글
JAX 공부 - 쉬어가기 (0) | 2022.07.20 |
---|---|
JAX 공부 11 - stateful computation in JAX (0) | 2022.07.20 |
JAX 공부 10 - Parallel Evaluation in JAX (0) | 2022.07.20 |
JAX 공부 9 - Pytrees (0) | 2022.07.19 |
JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX (0) | 2022.07.19 |