
JAX Study 9 - Pytrees

Chanwoo Park 2022. 7. 19. 19:14

https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html


I'm writing this post as I work through the JAX documentation page "Working with Pytrees".


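In a nutshell, a pytree is a hodgepodge of everything: any nesting of lists, tuples, dicts (and other registered containers), with arrays or other objects at the leaves. From this point on the doc trains an MLP, so let's read the code carefully.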
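To make "leaves" concrete, here is a small sketch with toy trees of my own (not from the doc) that lists leaves with `jax.tree_util.tree_leaves`. Lists, tuples, and dicts are nodes; everything else is a leaf; `None` is an empty node with no leaves; and a whole array counts as one leaf.

```python
import jax
import jax.numpy as jnp

# A few toy pytrees (hypothetical examples chosen for illustration).
nested = [1, {'k1': 2, 'k2': (3, 4)}, 5]   # list -> dict -> tuple
with_empty = (1, (2, 3), ())              # the empty tuple contributes no leaves
with_none = [None, 7]                     # None is an empty node, not a leaf
array_leaf = {'w': jnp.ones((2, 3))}      # a whole array is a single leaf

for tree in (nested, with_empty, with_none, array_leaf):
  leaves = jax.tree_util.tree_leaves(tree)
  print(len(leaves), leaves)
```

Dict children are visited in sorted key order, which is why `k1` comes before `k2`.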


First, init_mlp_params builds params as a list with one dict per layer; each dict holds that layer's weights and biases.


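import jax
import numpy as np

def init_mlp_params(layer_widths):
  # One dict per layer: weights scaled by sqrt(2/n_in) (He init), biases set to ones.
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])
# Apply x.shape to every leaf; the result keeps the same list-of-dicts structure.
jax.tree_util.tree_map(lambda x: x.shape, params)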

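jax.tree_map (jax.tree_util.tree_map in current JAX) applies a function to every leaf of a pytree. Here it applies x.shape to each weight and bias array, so the result is a pytree with the same list-of-dicts structure holding shapes instead of arrays; for example, the first layer's weights have shape (1, 128) and its biases (128,).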
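A quick check of this, using a toy parameter tree laid out like what init_mlp_params([1, 4, 1]) returns (the small widths and zero/one values are my own choice, so the shapes are easy to read). tree_map also accepts several trees with the same structure and maps over matching leaves, which is what the SGD update below relies on.

```python
import jax
import numpy as np

# Toy params with the same list-of-dicts layout as init_mlp_params([1, 4, 1]).
params = [
    dict(weights=np.zeros((1, 4)), biases=np.ones(4)),
    dict(weights=np.zeros((4, 1)), biases=np.ones(1)),
]

# One function applied to every leaf: the result keeps the list-of-dicts structure.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)

# With two trees of identical structure, tree_map pairs up corresponding leaves.
grads = jax.tree_util.tree_map(np.ones_like, params)  # fake "gradients" of all ones
stepped = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

If the two trees had different structures (say, a missing bias), the multi-tree tree_map would raise an error instead of silently misaligning leaves.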


import jax
import jax.numpy as jnp

def forward(params, x):
  # Every layer but the last applies ReLU; the last layer is linear.
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  # Mean squared error between predictions and targets.
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):

  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of the many JAX functions that has
  # built-in support for pytrees.

  # This is handy, because we can apply the SGD update using tree utils:
  return jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
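The doc then actually trains this MLP on toy data. Here is a minimal sketch of such a loop, assuming a target of y = x^2 on 128 random inputs (similar to the doc's toy setup), 1000 steps, and a fixed seed of my choosing. I restate the definitions so the block runs on its own.

```python
import jax
import jax.numpy as jnp
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
                       biases=np.ones(shape=(n_out,))))
  return params

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

np.random.seed(0)                      # assumption: fixed seed for reproducibility
params = init_mlp_params([1, 128, 128, 1])
xs = np.random.normal(size=(128, 1))
ys = xs ** 2                           # toy regression target: y = x^2

initial_loss = float(loss_fn(params, xs, ys))
for _ in range(1000):
  params = update(params, xs, ys)      # returns a new pytree; nothing is mutated in place
final_loss = float(loss_fn(params, xs, ys))
print(initial_loss, final_loss)
```

Note that update returns a fresh params pytree every step, so the loop just rebinds the name; the list-of-dicts structure stays the same throughout.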

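One thing I wondered here: why is jit applied only to update and not to forward? As far as I can tell, it is because forward and loss_fn are called from inside update, so jax.jit(update) traces through them and compiles the whole training step (forward pass, gradient, and parameter update) as one computation. A separate jit on forward only matters if you call forward on its own, for example for inference.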
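A sketch to check this reasoning (the counter and the tiny params below are my own, for illustration only). A Python side effect inside a jitted function runs only while JAX traces it, so counting it shows that jax.jit(update) traces forward as part of update and then reuses the compiled code on the next call with the same shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter, only to observe when the Python body of `forward` runs.
trace_count = {"forward": 0}

def forward(params, x):
  trace_count["forward"] += 1   # Python side effect: under jit it runs only while tracing
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - 0.0001 * g, params, grads)

params = [dict(weights=jnp.ones((1, 8)), biases=jnp.zeros(8)),
          dict(weights=jnp.ones((8, 1)), biases=jnp.zeros(1))]
x = jnp.linspace(-1.0, 1.0, 16).reshape(16, 1)
y = x ** 2

update(params, x, y)                  # first call: traces update, and forward inside it
after_first = trace_count["forward"]
update(params, x, y)                  # same shapes/dtypes: the compiled code is reused
after_second = trace_count["forward"]
print(after_first, after_second)

# forward can still be jitted on its own, e.g. for inference; results match the plain call.
fast_forward = jax.jit(forward)
```

The counter grows on the first update call and then stays put, which is consistent with forward being compiled into update rather than needing its own jit.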


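In a pytree, anything that is not a list, tuple, or dict (or one of the few other types JAX registers by default, such as None and namedtuples) is treated as a leaf. So if we use our own container class, each instance is treated as a single leaf and JAX does not look inside it. To fix this, we define a flatten and an unflatten function for the class and register them with JAX.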
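Before registering anything, a sketch of the problem (the MyContainer class here follows the doc's example): an unregistered instance is one opaque leaf, so tree_leaves returns the objects themselves and tree_map hands the whole object to the function.

```python
import jax

class MyContainer:
  """A named container of three fields; JAX does not know about it yet."""
  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

trees = [MyContainer('Alice', 1, 2, 3), MyContainer('Bob', 4, 5, 6)]

# Unregistered, each MyContainer is a single leaf: 2 leaves, not 6.
leaves = jax.tree_util.tree_leaves(trees)
print(len(leaves))

# tree_map therefore calls the function on the whole object, and `+ 1` fails.
failed = False
try:
  jax.tree_util.tree_map(lambda x: x + 1, trees)
except TypeError as err:
  failed = True
  print("TypeError:", err)
```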

import jax
from typing import Tuple, Iterable

class MyContainer:
  """A named container of three values."""

  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
  """Returns an iterable over container contents, and aux data."""
  flat_contents = [container.a, container.b, container.c]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
  """Converts aux data and the flat contents into a MyContainer."""
  return MyContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

# Now the ints inside are the leaves: [1, 2, 3, 4, 5, 6]
jax.tree_util.tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
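JAX also has a decorator form, jax.tree_util.register_pytree_node_class, where the class itself provides a tree_flatten method and a tree_unflatten classmethod. Here is a sketch with a hypothetical NamedBox class that mirrors MyContainer (this is an alternative style, not what the doc uses above). It also shows that once registered, tree_map reaches the integers while the name rides along as aux data.

```python
import jax
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class NamedBox:  # hypothetical class, mirrors MyContainer for illustration
  def __init__(self, name, a, b, c):
    self.name, self.a, self.b, self.c = name, a, b, c

  def tree_flatten(self):
    # (children, aux_data): the name is aux data, so it is never a leaf.
    return (self.a, self.b, self.c), self.name

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(aux_data, *children)

boxes = [NamedBox('Alice', 1, 2, 3), NamedBox('Bob', 4, 5, 6)]
print(jax.tree_util.tree_leaves(boxes))   # [1, 2, 3, 4, 5, 6]

bumped = jax.tree_util.tree_map(lambda x: x + 1, boxes)
print(bumped[0].name, bumped[0].a)        # Alice 2
```

Functionally this is the same as calling register_pytree_node with two free functions; the decorator just keeps the flatten/unflatten logic next to the class.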

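Alternatively, we can skip the manual registration by defining the container as a NamedTuple, at the cost of name becoming a leaf too: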

import jax
from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,
# this will work out-of-the-box:
jax.tree_util.tree_leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
# -> ['Alice', 1, 2, 3, 'Bob', 4, 5, 6]: every tuple field is a child,
#    so `name` now shows up as a leaf as well.
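Because name is now a leaf, a tree_map that does arithmetic on every leaf will hit the strings. A sketch of that pitfall, plus one simple workaround of my own (not from the doc): skip string leaves inside the mapped function.

```python
import jax
from typing import Any, NamedTuple

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

trees = [MyOtherContainer('Alice', 1, 2, 3), MyOtherContainer('Bob', 4, 5, 6)]
leaves = jax.tree_util.tree_leaves(trees)
print(leaves)  # ['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

# Naive arithmetic fails on the string leaf ('Alice' + 1).
failed = False
try:
  jax.tree_util.tree_map(lambda x: x + 1, trees)
except TypeError:
  failed = True

# Workaround (an assumption, not the doc's approach): leave strings untouched.
bumped = jax.tree_util.tree_map(lambda x: x if isinstance(x, str) else x + 1, trees)
print(bumped[0])  # MyOtherContainer(name='Alice', a=2, b=3, c=4)
```

If the name must never be touched by transformations like jax.grad or jax.jit, registering the class with name as aux data (as above) is the cleaner option.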