티스토리 뷰

GAN을 공부하다가 tutorial 코드에 netD(fake.detach())가 어떤 원리인 지 이해가 안 갔다. 의도야 설명에 나온대로 netG에 backpropagation이 안되도록, 즉 첫번째 스텝에서는 netD만 학습하려는 것이라는 건 알겠다. (code: pytorch gan tutorial)

## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch
errD_fake.backward()
D_G_z1 = output.mean().item()
# Add the gradients from the all-real and all-fake batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()

그런데 나는 detach()가 Tensor의 requires_grad를 False로 만드는 것으로 알고 있다. 그렇다면 netD(fake.detach())는 netD의 weight들의 requires_grad를 False로 만들어 netD를 학습하지 못하게 하는 것 아닌가? 라는 착각을 했다.

pytorch에서 쓰이는 computation graph, requires_grad, backpropagation, gradient 등을 이해하지 못하고 있었기 때문에 생긴 의문이었다.

결론은, fake.detach()는 netD의 weight들(model parameters)의 requires_grad에 아무런 영향을 주지 않으며 fake 본인의 requires_grad를 False로 만듦으로써 errD_fake.backward()가 실행되도 fake 본인의 grad는 여전히 None이고, detach()에 의해 grad_fn도 None으로 바뀌었기 때문에 netG의 weight들로 backpropagation도 할 수 없다. (그리고 사실 GAN이 아니라 일반적인 one stage Neural Network에서도 input의 requires_grad가 True일 필요는 전혀 없다. 우리가 학습하려는 것은 오직 model parameters, 즉 weight들이기 때문에 weight Tensor의 requires_grad만 True이면 된다)

GAN의 Computation Graph, Output == errD_fake, F=fake라 치자.

즉 위의 사진에서 fake.detach()를 하면 F의 grad_fn인 F/W은 다 None이 되고, errD_fake.backward()를 하면 원래 Out/F가 chain rule에 의해 계산이 되어야 하는데 F의 requires_grad가 False니까 None이 된다. 즉 GAN의 computation graph에서 밑의 사진의 노란색 박스 부분이 빠지게 되는 효과가 생기는 것이다 == 'netG에 backpropagation이 안된다'

노란색 박스 부분이 GAN의 computation graph에서 빠지게 됨

이것을 이해하기까지 하루가 걸렸는데, 내가 놓친 기초들이 있었기 때문이고 위의 내용을 이해하기 위해 필수적인 부분들은 다음과 같다.

  • pytorch Tensor 객체의 detach()의 효과 
  • Compuation Graph에 대한 이해
  • forward pass와 backpropagation의 의미
  • pytorch Tensor 객체의 grad, grad_fn의 의미와 그것들과 backpropagtion의 관계
  • pytorch Tensor의 requires_grad가 어떻게 설정되는 지

당연히 위의 내용들을 아예 몰랐던 것이 아니지만 '정확히' 알지 못했기 때문에 fake.detach()를 이해하는데 오래걸렸다.

pytorch Tensor 객체의 detach()의 효과

Pytorch document에는 'Returns a new Tensor, detached from the current graph. The result will never require gradient.'라고 나와있다. 지금이야 이해가 가지만 모를 땐 좀 부족한 설명이다. 

 The detach() method constructs a new view on a tensor which is declared not to need gradients, i.e., it is to be excluded from further tracking of operations, and therefore the subgraph involving this view is not recorded. link

위의 설명이 pytorch detach()에 대한 가장 보편적인 설명인데, 밑줄 친 부분이 헷갈려 문제가 생겼다. excluded from further tracking이라는 것이 내가 위에 만든 Computation Graph에서 netD 파트의 gradient를 추적하지 않겠다는 얘기인 줄 알았다. 즉 subgraph involving this view is not recorded 가 노란박스를 빼는게 아니라 그 위의 부분을 computation graph에서 빼겠다는 걸로 해석되서 매우 헷갈렸다. 실제로는 

excluded from further tracking of operations 
-> fake.requires_grad = False 
-> forward pass할 때 
O1/F, O2/F를 저장하지 않음(O1.grad_fn, O2.grad_fn에)
-> errD_fake.backward()를 해도 fake.grad = None

subgraph involving this view is not recorded 
-> fake.grad_fn = None
-> netG로 backpropagation이 넘어가지 않음

사실 grad=None인데, grad_fn이 None이 아니면 에러가 날 것이기에 굳이 둘을 detach()의 별개의 효과로 구분지을 필요는 없을 것 같다.

Compuation Graph에 대한 이해

사실 위의 Computation Graph 모형도에서 작은 노란박스 초록박스, 파란색 노드들이 무엇인지 조금이라도 이해가 안가면 Computation Graph에 나처럼 이해가 없는 사람이다. 파란색 노드는 Operation이 이루어지는 Tensor 변수들로 화살표가 곱셉, 덧셈 등의 operation이고 노드는 화살표(Operation)에 대한 input이자 output이다.  작은 노란박스는 Forward Pass를 할 때 Tensor의 grad_fn, 즉 operation output에 대한 input의 미분함수이다. 실제 이 grad_fn을 들고있는 Tensor는 output Tensor이다. 그리고 작은 초록박스가 바로 chain rule에 의해 backpropagation으로 얻는 Tensor의 gradient로 optimizer.step()에 의해 해당 Tensor를 update하는 값이다. Tensor, grad_fn, grad의 한 세트는 밑의 사진과 같다. 이 링크에 Compuation grap와 backpropagation에 대해 정말 잘 설명되어있다.

Tensor, grad_fn, grad의 한 세트. O1의 attribute

forward pass와 backpropagation의 의미

내가 forward pass에 대해 크게 간과한 부분이다. 단순히 output을 계산하는 과정을 forward pass라 하는 줄 알았는데, 바로 이 forward pass를 할 때 computation graph를 만들게 된다. 즉 내가 그린 모형도에서 초록색박스, Tensor의 gradient,를 뺀 나머지 부분들을 만드는 것이다!! 

When computing the forwards pass, autograd simultaneously performs the requested computations and builds up a graph representing the function that computes the gradient (the .grad_fn attribute of each torch.Tensor is an entry point into this graph). When the forwards pass is completed, we evaluate this graph in the backwards pass to compute the gradients. link

즉 forward pass를 할 때 각 Tensor의 grad_fn을 저장했다가 backpropagation을 할 때 이 grad_fn으로 각 Tensor의 gradient를 chain rule로 계산해내는 것이다. inference를 할 때 with torch.no_grad()나 detach()해서 메모리를 아낀다는 말이 곧 Tensor의 grad_fn을 저장하지 않아 아낀다는 것이었다. (당연히 저장하는 프로세스가 없으니까 속도도 빨라지고) Training과 달리 backpropagation을 안하니까 grad_fn이 필요가 없기 때문이다. 

pytorch Tensor 객체의 grad, grad_fn의 의미와 그것들과 backpropagtion의 관계

쓰다보니 이 부분은 자동으로 설명이 된 것 같다.

pytorch Tensor의 requires_grad가 어떻게 설정되는 지

먼저 Tensor의 requires_grad는 default가 False이다.
nn.Module을 통해 만든 model parameters의 requires_grad는 default가 True이다. link

그렇다면 다른 Tensor에서 계산되어 나온 Output Tensor의 requires_grad는? 하나라도 True면 Ouput의 requires_grad도 True! link

If there’s a single input to an operation that requires gradient, its output will also require gradient. Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.

>>> x = torch.randn(5, 5)  # requires_grad=False by default
>>> y = torch.randn(5, 5)  # requires_grad=False by default
>>> z = torch.randn((5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

netD(fake.detach())의 이유에 대해 찾다보니 내가 궁금한 이유에 대해서는 사람들이 다 알고있는 지 질문하지 않고, 
optimizerD.step()은 결국 netD의 parameters만 업데이트 할텐데 굳이 detach()를 하는 이유가 뭐냐는 질문이 많았다. 

이유는 netD를 학습하기 위해 쓴 fake를 netG를 학습하는 단계에서 재활용하기 때문인데, netG 파트에 해당하는 computation graph를 보존하기 위함이다. Default로 .backward()함수는 retain_graph=False인데, 때문에 backward()를 한 번 실행하면 computation graph가 다 날아가버린다 == grad_fn이 다 없어진다. 그래서 같은 그래프에 backward()를 다시 하려 그러면 오류가 난다.
(처음에 .backward(retain_graph=True)를 하지 않는 이상)

##### Update D network (윗부분 생략)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()

D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
# Update D
optimizerD.step()

##### Update G network
netG.zero_grad()
label.fill_(real_label)  
output = netD(fake).view(-1) # fake 재활용
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()

위에서 detach()에 대해 설명할 때 혼동되는 부분이 있는데, 참고로 detach()는 in-place함수가 아니라 requires_grad, grad_fn이 각각 False, None인 "새로운" Tensor를 리턴한다!

내가 옮겨 적은 plopd의 답변: https://github.com/pytorch/examples/issues/116

pytorch의 backward(retain_graph=False)에 대한 깔끔한 설명: https://jdhao.github.io/2017/11/12/pytorch-computation-graph/

 

Computational Graphs in PyTorch

PyTorch is a relatively new deep learning library which support dynamic computation graphs. It has gained a lot of attention after its official release in January. In this post, I want to share what I have learned about the computation graph in PyTorch. Wi

jdhao.github.io

 

'Research (연구 관련)' 카테고리의 다른 글

Ad hoc categories  (0) 2019.05.23
Why do we need validation set?  (0) 2019.05.15
Transposed Convolution  (0) 2019.05.10
nohup on other file  (0) 2019.04.29
2D 이미지와 퓨리에 변환  (1) 2019.04.23
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
글 보관함