
JAX Study 10 - Parallel Evaluation in JAX

Chanwoo Park 2022. 7. 20. 15:11

https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#specifying-in-axes

 

I wrote this post while working through the document above.

 

In practice this is not much different from working with GPUs. The document is about running on TPUs, and you can basically think of it as swapping jax.vmap for jax.pmap. One difference is that a parallelized function returns a ShardedDeviceArray, because the result is sharded across all the devices taking part in the parallel computation.
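
To make the swap concrete, here is a minimal sketch (the function f and its input are just made-up illustrations, and it assumes JAX can see at least one accelerator; the leading axis of the input must match jax.local_device_count()):

import jax
import jax.numpy as jnp

# A toy per-device function; pmap maps it over the leading axis, one slice per device.
def f(x):
  return x ** 2 + 1.0

n = jax.local_device_count()              # e.g. 8 on a TPU v2/v3 host
xs = jnp.arange(n, dtype=jnp.float32)     # leading axis must not exceed the device count

out = jax.pmap(f)(xs)
print(type(out))   # ShardedDeviceArray here (newer JAX versions return jax.Array)
print(out)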

 

in_axes is a concept I keep getting confused by, so I am pulling in an extra reference:

https://jiayiwu.me/blog/2021/04/05/learning-about-jax-axes-in-vmap.html

 

Reading that makes it click: the axis you give in in_axes decides how the function is mapped over each argument. If an entry is None, that argument is not mapped at all; the whole value is passed as-is to every call.
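
A small sketch of what that means in practice (scale, w and xs are made-up names; the same in_axes semantics carry over to pmap):

import jax
import jax.numpy as jnp

def scale(w, x):
  return w * x

w = 2.0
xs = jnp.arange(4.0)   # shape (4,)

# in_axes=(None, 0): `w` is shared by every call, `x` is mapped over its leading axis.
print(jax.vmap(scale, in_axes=(None, 0))(w, xs))          # [0. 2. 4. 6.]

# in_axes=(None, 1) maps over the second axis of a 2-D `x` instead.
xs2 = jnp.arange(6.0).reshape(2, 3)
print(jax.vmap(scale, in_axes=(None, 1))(w, xs2).shape)   # (3, 2)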

 

Because pmap JIT-compiles the function automatically, there is no need to wrap it in jax.jit as well.

 

It is also a good idea to give the mapped axis a name with axis_name (the mapped axis itself stays invisible inside the function, and the name is how collectives refer to it). jax.pmap and jax.vmap can also be nested.
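
A brief sketch of axis_name and nesting, following the pattern in the tutorial (again assuming at least one visible device, so n >= 1):

import functools
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# axis_name labels the mapped axis so collectives like psum/pmean can refer to it.
@functools.partial(jax.pmap, axis_name='i')
def normalize(x):
  return x / jax.lax.psum(x, axis_name='i')   # sum over all devices along axis 'i'

print(normalize(jnp.arange(1.0, n + 1.0)))    # the entries now sum to 1

# pmap and vmap can be nested; each level gets its own axis name.
@functools.partial(jax.pmap, axis_name='rows')
@functools.partial(jax.vmap, axis_name='cols')
def doubly_normalize(x):
  return x / jax.lax.psum(x, axis_name=('rows', 'cols'))

batch = jnp.arange(float(n * 4)).reshape(n, 4)
print(doubly_normalize(batch))                # all n * 4 entries sum to 1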

 

Reading the example in the docs shows how the parallelization actually plays out.

 

You split the data by the number of devices, replicate the params, and then train. It is more intuitive than you might expect, so I recommend reading through it yourself.

from typing import NamedTuple, Tuple
import functools

import jax
import jax.numpy as jnp
import numpy as np

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray


def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)


def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * xs + params.bias
  return jnp.mean((pred - ys) ** 2)

LEARNING_RATE = 0.005

# So far, the code is identical to the single-device case. Here's what's new:


# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
  """Performs one SGD update step on params using the given data."""

  # Compute the gradients on the given minibatch (individually on each device).
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)

  # Combine the gradient across all devices (by taking their mean).
  grads = jax.lax.pmean(grads, axis_name='num_devices')

  # Also combine the loss. Unnecessary for the update, but useful for logging.
  loss = jax.lax.pmean(loss, axis_name='num_devices')

  # Each device performs its own update, but since we start with the same params
  # and synchronise gradients, the params stay in sync.
  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grads)

  return new_params, loss


# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise

# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

def split(arr):
  """Splits the first axis of `arr` evenly across the number of devices."""
  return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)

print(type(x_split))  # still a plain NumPy array on the host at this point

def type_after_update(name, obj):
  print(f"after first `update()`, `{name}` is a", type(obj))

# Actual training loop.
for i in range(1000):

  # This is where the params and data gets communicated to devices:
  replicated_params, loss = update(replicated_params, x_split, y_split)

  # The returned `replicated_params` and `loss` are now both ShardedDeviceArrays,
  # indicating that they're on the devices.
  # `x_split`, of course, remains a NumPy array on the host.
  if i == 0:
    type_after_update('replicated_params.weight', replicated_params.weight)
    type_after_update('loss', loss)
    type_after_update('x_split', x_split)

  if i % 100 == 0:
    # Note that loss is actually an array of shape [num_devices], with identical
    # entries, because each device returns its copy of the loss.
    # So, we take the first element to print it.
    print(f"Step {i:3d}, loss: {loss[0]:.3f}")


# Plot results.

# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))

split divides up the data, and feeding the replicated params into the pmapped update is what makes it run across multiple TPU cores.