https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#specifying-in-axes
I wrote this while studying that document.
In practice it is almost no different from the usual GPU case. The document is about running on TPUs, and you can think of it as swapping jax.vmap for jax.pmap. The difference is that the parallelized function returns a ShardedDeviceArray, because the array is spread across all of the devices taking part in the parallelism.
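For context, here is a minimal sketch of that vmap-to-pmap swap. It mirrors the convolve example from the linked doc and assumes the machine has multiple local devices (e.g. the 8 TPU cores used there); on a single device it still runs, just without any actual parallelism.

import jax
import jax.numpy as jnp
import numpy as np

def convolve(x, w):
    # A simple 1D "valid" convolution over a single example.
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

n_devices = jax.local_device_count()
xs = np.arange(5 * n_devices, dtype=np.float32).reshape(n_devices, 5)
w = np.array([2., 3., 4.], dtype=np.float32)

# Same call signature as jax.vmap, but each row of xs runs on its own device.
out = jax.pmap(convolve, in_axes=(0, None))(xs, w)
print(type(out))  # ShardedDeviceArray (jax.Array in newer JAX versions)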
in_axes is a concept I find confusing, so I'm pulling in an extra reference from elsewhere.
https://jiayiwu.me/blog/2021/04/05/learning-about-jax-axes-in-vmap.html
Reading this makes it click: the axis you pass to in_axes decides how the function is applied over each argument. For None, think of the whole argument being handed unchanged to every call.
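A small sketch of that point using jax.vmap (the arrays xs, w, and ws here are made up for illustration): axis 0 maps over the leading dimension, while None broadcasts the same argument to every call.

import jax
import jax.numpy as jnp

xs = jnp.arange(6.).reshape(3, 2)   # a batch of 3 two-dimensional vectors
w = jnp.array([10., 20.])
ws = jnp.stack([w, 2 * w])          # two different weight vectors, shape (2, 2)

# in_axes=(0, None): map over the rows of xs, reuse the same w in every call.
print(jax.vmap(jnp.dot, in_axes=(0, None))(xs, w))    # shape (3,)

# in_axes=(None, 0): pass all of xs to every call, map over the rows of ws.
print(jax.vmap(jnp.dot, in_axes=(None, 0))(xs, ws))   # shape (2, 3)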
Because pmap JIT-compiles automatically, there is no need to wrap the function in jax.jit when using pmap.
It is also good to give the mapped axis an axis name (the mapped axis itself is invisible to us inside the function), and jax.pmap and jax.vmap can be nested.
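A sketch of both points (normalize is a hypothetical helper; with only one local device the pmean reduces over a single value, so the code still runs but the result is trivial): the axis name exists only so that collectives such as jax.lax.pmean know which mapped axis to reduce over.

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()

def normalize(v):
    # 'devices' names the axis created by the pmap below; pmean averages
    # the per-device values of v across that axis.
    return v / jax.lax.pmean(v, axis_name='devices')

x = np.arange(1, n_devices + 1, dtype=np.float32)
print(jax.pmap(normalize, axis_name='devices')(x))

# pmap and vmap can be nested: the outer pmap maps over devices, the inner
# vmap maps over a per-device batch dimension.
batch = np.ones((n_devices, 4), dtype=np.float32)
print(jax.pmap(jax.vmap(jnp.sin))(batch).shape)  # (n_devices, 4)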
Reading the example in the docs shows how the parallelization is done.
It splits the data by the number of devices, replicates the params, and then runs. It is more intuitive than you might expect, so I recommend reading it yourself.
from typing import NamedTuple, Tuple
import functools

import jax
import jax.numpy as jnp
import numpy as np
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray


def init(rng) -> Params:
    """Returns the initial model params."""
    weights_key, bias_key = jax.random.split(rng)
    weight = jax.random.normal(weights_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)


def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Computes the least squares error of the model's predictions on x against y."""
    pred = params.weight * xs + params.bias
    return jnp.mean((pred - ys) ** 2)


LEARNING_RATE = 0.005

# So far, the code is identical to the single-device case. Here's what's new:
# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    """Performs one SGD update step on params using the given data."""
    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Combine the gradient across all devices (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    # Also combine the loss. Unnecessary for the update, but useful for logging.
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    # Each device performs its own update, but since we start with the same params
    # and synchronise gradients, the params stay in sync.
    new_params = jax.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, grads)
    return new_params, loss


# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)


def split(arr):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])


# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)
type(x_split)


def type_after_update(name, obj):
    print(f"after first `update()`, `{name}` is a", type(obj))


# Actual training loop.
for i in range(1000):

    # This is where the params and data gets communicated to devices:
    replicated_params, loss = update(replicated_params, x_split, y_split)

    # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
    # indicating that they're on the devices.
    # `x_split`, of course, remains a NumPy array on the host.
    if i == 0:
        type_after_update('replicated_params.weight', replicated_params.weight)
        type_after_update('loss', loss)
        type_after_update('x_split', x_split)

    if i % 100 == 0:
        # Note that loss is actually an array of shape [num_devices], with identical
        # entries, because each device returns its copy of the loss.
        # So, we take the first element to print it.
        print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.
# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))
The data is divided across devices with split, and the rest of the loop runs on multiple TPUs via the replicated params.