연구/JAX

JAX 공부 3 - Pure function, array updates

Chanwoo Park 2022. 7. 18. 14:29
728x90

 

 

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

 

🔪 JAX - The Sharp Bits 🔪 — JAX documentation

 

jax.readthedocs.io

Pure function 개념에 대해서 좀 헷갈려서 적으면서 정리하고자 한다.

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect 
  return x

# The side-effects appear during the first run  
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

이를 시행하면 다음과 같다. 

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]

Executing function이 두 번째에 안 나온다. 이는 cache에 함수가 저장되었기 때문이다. 즉, 우리가 생각하는 대로 일어나지 않기 때문에, 우리는 최대한 pure function을 짜야만 한다. 

g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

다른 자료형이 들어왔기 때문에 g가 업데이트 하는 것을 결과를 통해서 확인 가능하다. 

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]

JAX에는 iterator를 쓰지 않는 것을 추천한다. iterator는 JAX의 functional programming model이랑 달라서, 예상치 못한 결과를 낼 수 있다. 내가 함수형 프로그래밍이 처음이라, 이것이 어떻게 진행되는지 좀 더 공부해봐야 알 것 같다. 기존의 내가 쓰던 pytorch 또한 iterator를 많이 썼는데, 이를 어찌해야 하는가... 싶다. 점점 example을 보면서 공부해봐야 알 것 같다. 

 

 

마지막으로 array update는 기존에 한 번 다뤘기 때문에, 예제 코드만 두고 넘어가도록 한다. 

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]