연구/JAX

JAX 공부 11 - stateful computation in JAX

Chanwoo Park 2022. 7. 20. 15:38
728x90

https://jax.readthedocs.io/en/latest/jax-101/07-state.html

 

Stateful Computations in JAX — JAX documentation

 

jax.readthedocs.io

이 문서를 공부하면서 작성한 글이다. 

 

import jax
import jax.numpy as jnp

class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
    """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
    """Resets the counter to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())

이걸 하면 당연히 1, 2, 3이 나온다. 그런데, jit을 쓰면

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  print(fast_count())

1, 1, 1이 나온다. count가 한 번만 컴파일되기 때문이다. return value가 count의 argument에 영향을 안 받으므로 계속 1을 아웃풋 한다. 따라서, argument에 영향을 받게 코드를 수정해주면 된다. (이게 좀 어려운 부분인 듯하다. 무의식적으로 위와 같이 코딩할 것 같다는 생각이 든다. )

from typing import Tuple

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> Tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
  value, state = counter.count(state)
  print(value)
  
state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)

즉, Stateless one을 Stateful Computation으로 바꾸어야 한다는 뜻이다. 이를 general하게 기술하면,

class StatefulClass

  state: State

  def stateful_method(*args, **kwargs) -> Output:

class StatelessClass

  def stateless_method(state: State, *args, **kwargs) -> (Output, State):

로 바꾼다는 뜻이다. 이는 functional programming에서 자주 쓰는 패턴이라고 한다. 

 

글에 내용이 너무 없는데, 실제로 문서에도 내용이 너무 없는거 같다. 쓰면서 익숙해져야 하는 부분인 듯하다.