728x90
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html을 공부하면서 작성 중이다.
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
try:
f(2)
except Exception as e:
print("Exception {}".format(e))
안 되는 것을 확인 가능하다. 이걸 작성하면서 살짝 이해한 거 같음. jit의 특성이랑 꽤 관련이 있다. 어쨌든 전에 cache에 저장 도면서 머가 될지를 결정해야 하는데, trace variable에 문제가 생기기 때문이다. JIT은 ShapedArray라는 abstraction level로 코드를 trace 한다. 그래서 분기점이 나타난다면 (if 문같이) True, False를 나타낼 수 없기 때문에 에러를 발생시킨다. 위의 문제를 해결해주기 위해서 다음과 같은 일을 해준다.
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
function이 global side effect를 가지고 있을 경우에도, 마찬가지다. print문을 안에 넣어둘 경우 이상한 것이 발생 가능하다. (global side effect) 흠... 갑자기 드는 의문이, 이렇다면 어떻게 디버깅을 잘 해내야 할까 싶다.
4가지 모드의 컨트롤 가능 불가능의 영역이 있다.
cond는 다음과 같은 2개의 코드가 동치라고 한다.
while의 경우 forward mode differentiable인데 reverse의 경우 불가능하다고 한다.
For의 경우 다음과 같다.
Nan 관련 이야기도 있는데, 그냥 내가 읽기에는 다음만 추가하면 되는 것 같다.
'연구 > JAX' 카테고리의 다른 글
JAX 공부 7 - difference from numpy (1) | 2022.07.19 |
---|---|
JAX 공부 6 - jax as accelerated numpy, grad (0) | 2022.07.19 |
JAX 공부 4 - out of bounds indexing, non array input, random numbers (0) | 2022.07.18 |
JAX 공부 3 - Pure function, array updates (0) | 2022.07.18 |
JAX 공부 2 - JIT 기초 (0) | 2022.07.18 |