728x90
https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
Working with Pytrees — JAX documentation
jax.readthedocs.io
이 문서를 공부하며 작성하고 있다.
pytree는 한마디로 모든 걸 짬뽕해둔 거라 할 수 있다. 여기부터 mlp를 training 하는 게 나오므로 코드를 잘 보도록 하자.
처음 init_mlp_parameters 는 params에 dict를 끼워 넣는 것이다. weight와 bias를 넣게 된다.
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
jax.tree_map(lambda x: x.shape, params)
우리는 jax.tree_map으로 tree들에 대한 각 function을 적용해준다. 각 원소들에 대해서 x.shape를 적용해준 거 같다.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
LEARNING_RATE = 0.0001
@jax.jit
def update(params, x, y):
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of the many JAX functions that has
# built-in support for pytrees.
# This is handy, because we can apply the SGD update using tree utils:
return jax.tree_map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
여기서 궁금한 점은 왜 jit을 update에만해주고 forward에는 안 해주는가이다.
pytree에서 list, tuples, dicts를 제외하고는 모두 leaf로 여겨진다. 그래서 우리가 container class를 사용한다면, leaf로 취급이 된다. 이제 여기서 flatten과 unflatten을 정의해주고 register 해주여야 한다.
from typing import Tuple, Iterable
def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
flat_contents = [container.a, container.b, container.c]
# we don't want the name to appear as a child, so it is auxiliary data.
# auxiliary data is usually a description of the structure of a node,
# e.g., the keys of a dict -- anything that isn't a node's children.
aux_data = container.name
return flat_contents, aux_data
def unflatten_MyContainer(
aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
"""Converts aux data and the flat contents into a MyContainer."""
return MyContainer(aux_data, *flat_contents)
jax.tree_util.register_pytree_node(
MyContainer, flatten_MyContainer, unflatten_MyContainer)
jax.tree_leaves([
MyContainer('Alice', 1, 2, 3),
MyContainer('Bob', 4, 5, 6)
])
이를 이렇게 바꿀 수 있다.
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,
# this will work out-of-the-box:
jax.tree_leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
'연구 > JAX' 카테고리의 다른 글
JAX 공부 11 - stateful computation in JAX (0) | 2022.07.20 |
---|---|
JAX 공부 10 - Parallel Evaluation in JAX (0) | 2022.07.20 |
JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX (0) | 2022.07.19 |
JAX 공부 8 - JIT in JAX (0) | 2022.07.19 |
JAX 공부 7 - difference from numpy (1) | 2022.07.19 |