728x90
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
Pure function 개념에 대해서 좀 헷갈려서 적으면서 정리하고자 한다.
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
이를 시행하면 다음과 같다.
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
Executing function이 두 번째에 안 나온다. 이는 cache에 함수가 저장되었기 때문이다. 즉, 우리가 생각하는 대로 일어나지 않기 때문에, 우리는 최대한 pure function을 짜야만 한다.
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
다른 자료형이 들어왔기 때문에 g가 업데이트 하는 것을 결과를 통해서 확인 가능하다.
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
JAX에는 iterator를 쓰지 않는 것을 추천한다. iterator는 JAX의 functional programming model이랑 달라서, 예상치 못한 결과를 낼 수 있다. 내가 함수형 프로그래밍이 처음이라, 이것이 어떻게 진행되는지 좀 더 공부해봐야 알 것 같다. 기존의 내가 쓰던 pytorch 또한 iterator를 많이 썼는데, 이를 어찌해야 하는가... 싶다. 점점 example을 보면서 공부해봐야 알 것 같다.
마지막으로 array update는 기존에 한 번 다뤘기 때문에, 예제 코드만 두고 넘어가도록 한다.
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
'연구 > JAX' 카테고리의 다른 글
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
---|---|
JAX 공부 5 - Control Flow (0) | 2022.07.18 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |
JAX 공부 2 - JIT 기초 (0) | 2022.07.18 |
JAX 공부 - jax.numpy, jax.lax 기초 (1) | 2022.07.18 |