
06/20/2024

Similar to PyTorch's nn.Module, JAX's neural network library FLAX also has an nn.Module.

Notes on what I learned while trying to map it one-to-one to PyTorch's nn.Module:

__init__ -> setup

forward -> __call__ (this part is really up to you, but if you want to call the instance the way PyTorch calls forward, this is how you do it)

What confused me is that setup is a method that takes no other arguments. The reason is that Flax nn.Module assumes Python 3.7 dataclasses and is meant to be subclassed that way. In other words, anything you declare as a class variable becomes an instance variable, so you don't write __init__ and you don't need to pass arguments to setup. So this is why people keep going on about PEPs...

Reference: https://www.toptal.com/python/python-class-attributes-an-overly-thorough-guide

https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html
https://docs.python.org/3/library/dataclasses.html
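
For example, something like this should work (a minimal sketch; the module and attribute names are mine): the annotated class variables play the role of __init__ arguments, and setup just uses them without taking any arguments of its own.

```python
import flax.linen as nn

class MLP(nn.Module):
    # dataclass fields: these play the role of __init__ arguments
    hidden_dim: int = 128
    out_dim: int = 10

    def setup(self):
        # no arguments besides self; the fields above are already set
        self.fc1 = nn.Dense(self.hidden_dim)
        self.fc2 = nn.Dense(self.out_dim)

    def __call__(self, x):
        x = nn.gelu(self.fc1(x))
        return self.fc2(x)
```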

nn.Linear -> nn.Dense, where

nn.Dense seems to be a lazy version that only gets materialized once it is first called. So you don't specify the input channel dim; there is only an argument for the output channel dim.

https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html

https://stackoverflow.com/questions/78249695/how-can-i-convert-a-flax-linen-module-to-a-torch-nn-module
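
A minimal sketch of the contrast (the feature sizes are made up):

```python
import jax
import flax.linen as nn
import torch

# PyTorch: both input and output dims are required up front
torch_layer = torch.nn.Linear(in_features=64, out_features=128)

# Flax: only the output dim; the input dim is inferred from the example input at init time
flax_layer = nn.Dense(features=128)
x = jax.numpy.ones((1, 64))
variables = flax_layer.init(jax.random.PRNGKey(0), x)  # kernel shape becomes (64, 128)
y = flax_layer.apply(variables, x)
```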

nn.Dropout is a bit of a headache: unlike PyTorch, it needs a deterministic argument, and when you initialize the whole model or switch between eval/train modes, you have to pass train and a dropout rng key.

This guide should cover it: https://flax.readthedocs.io/en/latest/guides/training_techniques/dropout.html

This page makes a big fuss about dropout and it seems like overkill; the example code and the explanation don't match up. Still, just in case: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/arguments.html
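
For reference, the usual pattern seems to be wiring deterministic to a train flag and passing the 'dropout' rng through apply. A minimal sketch (the module and the train argument name are mine):

```python
import jax
import flax.linen as nn

class Block(nn.Module):
    rate: float = 0.5

    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Dense(16)(x)
        # dropout is only active when train=True
        x = nn.Dropout(rate=self.rate, deterministic=not train)(x)
        return x

model = Block()
x = jax.numpy.ones((2, 8))
variables = model.init(jax.random.PRNGKey(0), x)          # train=False: no rng needed
y = model.apply(variables, x, train=True,
                rngs={'dropout': jax.random.PRNGKey(1)})  # train=True: pass the dropout rng
```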


06/29/2024

What is the difference between torch.nn.Sequential and torch.nn.ModuleList?

nn.Sequential is itself a neural network: the nn.Modules you pass in are all connected and a forward pass is defined. That means if you assign it to a variable, you can run the forward pass through that variable name.

nn.ModuleList is really just a list. The nn.Modules in the Python list you pass in have no connection to each other; it is not a neural network, so no forward pass is defined. If you want to use the nn.Modules inside it, you have to index into it or loop over it. The reason to use nn.ModuleList instead of a plain Python list is to tell torch that the contents are trainable, so their parameters are returned when you call parameters().

https://discuss.pytorch.org/t/when-should-i-use-nn-modulelist-and-when-should-i-use-nn-sequential/5463/4
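
A minimal sketch of the difference (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# nn.Sequential: the modules are chained, so a forward pass is already defined
seq = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
y = seq(torch.ones(1, 8))          # call it directly

# nn.ModuleList: just a registered container; you write the forward pass yourself
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
x = torch.ones(1, 8)
for layer in layers:               # index or loop manually
    x = layer(x)
```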

 


Why can JAX/FLAX use just a Python list instead of something like torch.nn.ModuleList?

Unusually, JAX/FLAX lets you use a plain Python list. The FLAX linen Module overrides the Python class's __setattr__ method, and there it seems to also loop over Python lists and register the entries as submodules (trainable variables = 'params'). If you look at the source code, it calls a method named self._register_submodules. https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#Module._register_submodules

https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.__setattr__
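
A minimal sketch of what that looks like in practice (names are mine):

```python
import jax
import flax.linen as nn

class Stack(nn.Module):
    num_layers: int = 3

    def setup(self):
        # a plain python list is fine; __setattr__ registers each entry as a submodule
        self.layers = [nn.Dense(8) for _ in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

variables = Stack().init(jax.random.PRNGKey(0), jax.numpy.ones((1, 8)))
print(variables['params'].keys())  # e.g. layers_0, layers_1, layers_2
```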

What is the equivalent of torch.no_grad() in JAX/FLAX? How do you resume gradients? Should I even...?

Background: I was trying to port a function that initializes a pose_embedding parameter to JAX, and it turns out it uses torch.no_grad like this:

with torch.no_grad():
    return _trunc_normal_(tensor, mean, std, a, b)

torch builds the computation graph during the forward pass; that is, if tensor.requires_grad=True, every differentiable operation also creates a gradient function. Building a gradient function for every op at inference time is wasteful, so a context manager like torch.no_grad() is used to prevent those gradient functions from being created.

So I wondered whether JAX/FLAX also has a context manager to prevent building these gradient functions and searched around, but it seems JAX numpy tensors don't build gradient functions when you operate on them in the first place, and there doesn't appear to be any requires_grad-like attribute either. In JAX, the gradient functions only seem to be created and returned when you call jax.grad.

So if you look closely at the ViT JAX tutorial (I haven't read all of it, so this needs checking) https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html

there seems to be neither torch.no_grad nor jax.lax.stop_gradient in it. That said, it does seem you can use jax.lax.stop_gradient so that jax.grad skips the gradient-related operations: https://github.com/google/jax/issues/1937

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html
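
A minimal sketch of jax.lax.stop_gradient, along the lines of the linked issue: the wrapped value is treated as a constant by jax.grad, which is the closest thing to detaching in PyTorch.

```python
import jax
import jax.numpy as jnp

def loss(x):
    frozen = jax.lax.stop_gradient(x ** 2)  # treated as a constant by jax.grad
    return jnp.sum(frozen + x)

print(jax.grad(loss)(jnp.array(3.0)))  # 1.0, not 1.0 + 2*x
```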


06/30/2024

How are pseudo random numbers handled in JAX?

JAX does not allow a function to modify variables outside of its own scope (functions must be pure). So you cannot set a random seed once at the beginning of the program and then draw different pseudo-random numbers from successive random function calls. If it worked like NumPy, the random functions would effectively be modifying a variable (the random state) outside their scope.

The 'seed' initializes the 'PRNG state', and in NumPy/PyTorch that state is mutated on every random function call.

Instead, in JAX the PRNG state is passed as an argument to random functions. So if you use the same key (PRNG state), the generated random numbers are the same, and you need to split (or fork) the key before every random function call.

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html
import jax
import numpy as np

rng = jax.random.PRNGKey(42)

# A non-desirable way of generating pseudo-random numbers...
jax_random_number_1 = jax.random.normal(rng)
jax_random_number_2 = jax.random.normal(rng)
print('JAX - Random number 1:', jax_random_number_1)
print('JAX - Random number 2:', jax_random_number_2)

# Typical random numbers in NumPy
np.random.seed(42)
np_random_number_1 = np.random.normal()
np_random_number_2 = np.random.normal()
print('NumPy - Random number 1:', np_random_number_1)
print('NumPy - Random number 2:', np_random_number_2)
------------------------------------------------------
JAX - Random number 1: -0.18471177
JAX - Random number 2: -0.18471177
NumPy - Random number 1: 0.4967141530112327
NumPy - Random number 2: -0.13826430117118466
------------------------------------------------------

rng, subkey1, subkey2 = jax.random.split(rng, num=3)  # We create 3 new keys
jax_random_number_1 = jax.random.normal(subkey1)
jax_random_number_2 = jax.random.normal(subkey2)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)
------------------------------------------------------
JAX new - Random number 1: 0.107961535
JAX new - Random number 2: -1.2226542
------------------------------------------------------

How does JAX/FLAX handle randomness in neural network training, e.g., nn.Dropout? flax.linen.Dropout does not take a PRNG state key as an argument.

The dropout layer takes its random key by internally calling
`self.make_rng('dropout')`, which pulls and splits from a PRNG stream named
`'dropout'`. This means when we call `model.apply` we will need to define the
starting key for this PRNG stream. This can be done by passing a dictionary
mapping stream names to PRNG keys, to the `rngs` argument in `model.apply`:
```python
import jax
from flax import linen as nn

key = jax.random.PRNGKey(0)
key, x_key = jax.random.split(key)
key, drop_key = jax.random.split(key)
x = jax.random.normal(x_key, (3, 3))

model = nn.Dropout(0.5, deterministic=False)
y = model.apply({}, x, rngs={'dropout': drop_key})  # there is no state, just pass an empty dictionary :)
x, y
```
https://huggingface.co/blog/afmck/flax-tutorial

07/04/2024

What is XLA? When are jaxpr representations made?

XLA is an open-source compiler for machine learning projects. The jaxpr representations are made when an input is given and the function is run. I thought 'jitting' happens when the function is compiled, and that the jaxpr representations are made at 'jitting' time. But no. It feels more like a deferred instruction: right before the actual code execution, break the function down into jaxpr and then compile it.

Reference: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html


- To view the jaxpr representation of this function, we can use jax.make_jaxpr. Since the tracing depends on the shape of the input, we need to pass an input to the function (here of shape [3]):

- It achieves that by compiling functions just-in-time with XLA (Accelerated Linear Algebra), using their jaxpr representation. 

- Since the jaxpr representation of a function depends on the input shape, the compilation is started once we put the first input in. However, note that this also means that for every different shape we want to run the function, a new XLA compilation is needed. This is why it is recommended to use padding in cases where your input shape strongly varies 
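
A minimal sketch of both points (the function itself is arbitrary): jax.make_jaxpr needs a concrete input to trace, and jax.jit only triggers compilation on the first call with a given input shape.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2 + 1.0)

# tracing needs a concrete input (here of shape (3,)) to produce the jaxpr
print(jax.make_jaxpr(f)(jnp.ones(3)))

f_jit = jax.jit(f)        # nothing is compiled yet
f_jit(jnp.ones(3))        # first call with this shape: trace -> jaxpr -> XLA compile
f_jit(jnp.ones(3))        # same shape: reuses the compiled executable
f_jit(jnp.ones(5))        # new shape: triggers a new compilation
```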

How do you obtain (trainable) parameters that will be passed to the optimizer?

However, in contrast to PyTorch, the parameters are not part of the module. Instead, we can create a set of parameters of the module by calling its init() function. This function takes as input a PRNG state for sampling pseudo-random numbers and an example input to the model, and returns a set of parameters for the module as a pytree.

To be precise, it is a dictionary called variables, and inside it there are "params": the params dict, and "batch_stats": the batch-stats dict. Quirk: what PyTorch calls weight is called kernel here.

You pass this variables dictionary, or just {"params": params}, as an argument to apply.
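
A minimal sketch with a bare nn.Dense (any linen module works the same way):

```python
import jax
import flax.linen as nn

model = nn.Dense(features=4)
x = jax.numpy.ones((1, 8))

variables = model.init(jax.random.PRNGKey(0), x)  # a pytree: {'params': {'kernel': ..., 'bias': ...}}
params = variables['params']                      # note: the weight is called 'kernel' here

y = model.apply({'params': params}, x)            # pass the (sub)dict back to apply
```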

 

What is an optimizer in JAX?

We can create a TrainState which bundles the parameters, the optimizer, and the forward step of the model:

import optax
from flax.training import train_state

optimizer = optax.adam(learning_rate=1e-3)  # e.g. any optax optimizer

model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

This part I don't quite get: "Since JAX calculates gradients via function transformations, we do not have functions like backward(), optimizer.step() or optimizer.zero_grad() as in PyTorch. Instead, an optimizer is a function on the parameters and gradients."
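
My reading of that sentence, as a minimal sketch (the batch and the MSE loss are made up): the gradients come from transforming a loss function with jax.value_and_grad, and the optimizer update is just a pure function applied to (params, grads) via state.apply_gradients.

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(state, batch):
    x, y = batch

    def loss_fn(params):
        pred = state.apply_fn({'params': params}, x)
        return jnp.mean((pred - y) ** 2)

    # no loss.backward(): the gradient comes from transforming loss_fn
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # no optimizer.step(): the update is a pure function of (params, grads)
    state = state.apply_gradients(grads=grads)
    return state, loss
```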

How does training in JAX/FLAX differ from PyTorch?

In contrast to PyTorch, we do not need to explicitly push our model to GPU, since the parameters are already automatically created on GPU. Further, since the model itself is stateless, we do not have a train() or eval() function to switch between modes of e.g. dropout. When necessary, we can add an argument train : bool to the model forward pass.


07/18/2024

 

- To install JAX, you need NVIDIA driver version >= 525.60.13 for CUDA 12 on Linux.

https://jax.readthedocs.io/en/latest/installation.html

- You must not modify existing attributes inside setup.

https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.SetAttributeInModuleSetupError

- Maybe this is just how Python dataclasses work, but if you don't give a type annotation, the attribute is not recognized as a field. e.g. act = nn.gelu doesn't work; it has to be act: Any = nn.gelu.
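
For example (a minimal sketch; the names are mine):

```python
from typing import Callable
import flax.linen as nn

class Head(nn.Module):
    features: int = 32
    act: Callable = nn.gelu  # without the ': Callable' annotation this would not become a field

    @nn.compact
    def __call__(self, x):
        return self.act(nn.Dense(self.features)(x))
```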

 

 
