728x90
https://jax.readthedocs.io/en/latest/jax-101/07-state.html
Stateful Computations in JAX — JAX documentation
jax.readthedocs.io
이 문서를 공부하면서 작성한 글이다.
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
이걸 하면 당연히 1, 2, 3이 나온다. 그런데, jit을 쓰면
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1, 1, 1이 나온다. count가 한 번만 컴파일되기 때문이다. return value가 count의 argument에 영향을 안 받으므로 계속 1을 아웃풋 한다. 따라서, argument에 영향을 받게 코드를 수정해주면 된다. (이게 좀 어려운 부분인 듯하다. 무의식적으로 위와 같이 코딩할 것 같다는 생각이 든다. )
from typing import Tuple
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> Tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
즉, Stateless one을 Stateful Computation으로 바꾸어야 한다는 뜻이다. 이를 general하게 기술하면,
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
을
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
로 바꾼다는 뜻이다. 이는 functional programming에서 자주 쓰는 패턴이라고 한다.
글에 내용이 너무 없는데, 실제로 문서에도 내용이 너무 없는거 같다. 쓰면서 익숙해져야 하는 부분인 듯하다.
'연구 > JAX' 카테고리의 다른 글
JAX 공부 - 쉬어가기 (0) | 2022.07.20 |
---|---|
JAX 공부 12 - pjit (0) | 2022.07.20 |
JAX 공부 10 - Parallel Evaluation in JAX (0) | 2022.07.20 |
JAX 공부 9 - Pytrees (0) | 2022.07.19 |
JAX 공부 8 - Automatic Vectorization in JAX, Automatic Differentiation in JAX (0) | 2022.07.19 |