erodola / DLAI-s2-2021

Teaching material for the course of Deep Learning and Applied AI, 2nd semester 2021, Sapienza University of Rome

5_Autograd_and_Modules: reverse mode + re-use of tensor #17

Open sh3rlock14 opened 3 years ago

sh3rlock14 commented 3 years ago

I was going to ask for some clarification on what happens "under the hood" during the backward pass, but I think I figured it out, and I would like to share some thoughts with you to get feedback and hopefully dispel some doubts.

This is the code you gave us in the notebook to compute the rate of change at x = 42:

x = torch.tensor(42., requires_grad=True)
x2 = x ** 2
x3 = x ** 3
f = x2 + x3
f.backward(retain_graph=False)
x.grad  # = 5376

I played a bit with re-assigning the tensor we differentiate with respect to. At first I could not make sense of the results I was obtaining, so I sketched some graphs and derived every node step by step; here's what I found:

WhatIF (1):

x = torch.tensor(42., requires_grad=True)
x = x ** 2  # re-use the same variable
x3 = x ** 3
f = x + x3
x.retain_grad()
f.backward(retain_graph=False)
x.grad  # = 9335089

[Image: Backprop - WhatIf (1)]

and

WhatIF (2):

x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3  # re-use the same variable
f = y + x
x.retain_grad()
f.backward(retain_graph=False)
x.grad  # = 1

What I got so far is that reusing the "base tensor" modifies the computational graph, making it shorter and possibly truncating some computational subgraph, as here:

[Image: Backprop - WhatIf (3)]

dy/dx won't affect df/dx, and previous assignments to the variable x are ignored. Also, PyTorch keeps differentiating and substituting back towards the original x only if we defined new intermediate variables while building f: in this case there is no z variable in the graph, so we never substitute back to the original x and the derivative returns 1.
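As a numerical sanity check (a minimal sketch, assuming the snippets above run exactly as written), the reported gradients (5376, 9335089, and 1) match the hand-computed derivatives:

import torch

# Original snippet: f = x**2 + x**3, so df/dx = 2x + 3x**2 = 5376 at x = 42
x = torch.tensor(42., requires_grad=True)
f = x ** 2 + x ** 3
f.backward()
print(x.grad)  # tensor(5376.)

# WhatIF (1): after x = x ** 2, the graph is f = u + u**3 with u = 42**2 = 1764,
# and the retained grad on u is df/du = 1 + 3*u**2 = 9335089
u = torch.tensor(42., requires_grad=True) ** 2
u.retain_grad()
f = u + u ** 3
f.backward()
print(u.grad)  # tensor(9335089.)

# WhatIF (2): the re-assigned x is the tensor y**3 added directly into f,
# so df/dx = 1 regardless of its value
v = torch.tensor(42., requires_grad=True) ** 2
w = v ** 3
w.retain_grad()
f = v + w
f.backward()
print(w.grad)  # tensor(1.)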

lucmos commented 3 years ago

Hello, thanks for your insights!

I'd just like to make sure we all are on the same page.

The initial snippet is the following:

x = torch.tensor(42., requires_grad=True)
x2 = x ** 2
x3 = x ** 3
f = x2 + x3
f.backward(retain_graph=False)
x.grad

and it yields the following computational graph:

[image]

When you write:

x = torch.tensor(42., requires_grad=True)
x = x ** 2  # re-use the same variable
x3 = x ** 3
f = x + x3
x.retain_grad()
f.backward(retain_graph=False)
x.grad  # = 9335089

The computational graph changes:

[image]

But this change is not due to the re-assignment of the variable x per se. You are re-using the variable name, not modifying the underlying object it refers to (a tensor, in this case).

You did not modify the original tensor torch.tensor(42., requires_grad=True); you just re-assigned the variable x to another tensor (x ** 2) and thus lost all the references to that original object. To be precise, almost all references: there are still some references to it inside the computational graph.

The lines x3 = x ** 3 and f = x + x3 are responsible for the change, since the tensor that x refers to has changed by the time they run. For example, you can play a bit with temp variables:

x = temp0 =  torch.tensor(42., requires_grad=True)
x = temp1 = temp0 ** 2
x3 = temp0 ** 3
f = temp1 + x3
f.backward(retain_graph=False)

to get the same initial computational graph, even though x has been re-assigned:

[image]
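As a small side check (a sketch; x0 is just a hypothetical extra name), keeping a second reference to the original leaf shows that the re-assignment leaves it untouched and its gradient is still reachable:

import torch

x0 = torch.tensor(42., requires_grad=True)
x = x0           # two names for the same tensor
x = x ** 2       # x now refers to a new tensor; x0 is untouched
x3 = x0 ** 3
f = x + x3
f.backward()
print(x0.grad)   # tensor(5376.), same as in the original snippet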


Just for completeness, this is the computational graph for your second example:

x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3 #re-use the same variable
f = y + x
x.retain_grad()
f.backward(retain_graph=False)

[image]

It is the same as in your first example, not just because the operations happen to coincide (which may occur), but because the two programs are equivalent:

Example1:

x  =  torch.tensor(42., requires_grad=True)
x = x ** 2
x3 = x ** 3
f = x + x3

Can be simplified to: f = (x ** 2) + ((x ** 2) ** 3)

Example2:

x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3 #re-use the same variable
f = y + x

Can be simplified to: f = (x ** 2) + ((x ** 2) ** 3)

(To perform this simplification, proceed bottom-up and check what each variable refers to at each line.)
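A quick way to check this equivalence (a minimal sketch; note that the gradient below is the one accumulated on the original leaf, not the retained grad on the re-assigned x):

import torch

# Both examples reduce to the same expression on the original leaf
x = torch.tensor(42., requires_grad=True)
f = (x ** 2) + ((x ** 2) ** 3)
f.backward()
print(x.grad)  # df/dx = 2x + 6x**5 = 784147476 at x = 42 (printed with float32 rounding)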


Thus I disagree with this sentence:

reusing the "base tensor" modifies the computational graph, making it shorter and possibly truncating some computational subgraph, as here.

You are not modifying the original computational graph with that assignment; you are just re-using the name x to refer to a different tensor. (Of course, if the subsequent code uses the same variable, the resulting computational graph will be different, since the semantics of the code have changed.)

tl;dr: variables in Python are just references to objects (even for simple types like int, although you rarely notice since those are immutable); re-assigning a variable means making it point to a different object. It does not change the object the variable previously pointed to (e.g. a tensor, in this case).
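A tiny illustration of this (a sketch using id() to compare object identities):

import torch

a = torch.tensor(42., requires_grad=True)
b = a                    # a and b are two names for the same object
print(id(a) == id(b))    # True
a = a ** 2               # a now points to a brand new tensor
print(id(a) == id(b))    # False
print(b)                 # tensor(42., requires_grad=True): the original object is unchanged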


I am not sure whether that is exactly what you meant, but I wanted to make sure and to clarify a bit for everyone.

Let me know if I missed something or there are other doubts!

sh3rlock14 commented 3 years ago

Thanks for the reply!

I'll try to summarize what I got so far:

  1. I start by defining a new tensor x onto which gradients are accumulated.
  2. From it, I can define new tensors y, w, ..., z as intermediate nodes, each one having its own grad_fn.
  3. If I reuse the same variable x, I am allocating a new tensor in memory (with a new id); see the sketch after this list.
  4. The computational graph remains the same (since I may have defined some intermediate nodes as in 2.), but
  5. since f is defined on the latest assignments of its component tensors,
  6. reusing the same variables does not mean "shortening the graph", but rather "changing the semantics of the code", i.e. how the function is defined.
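
Here is a small sketch of points 2. and 3. (the names y and z are just placeholders):

import torch

x = torch.tensor(42., requires_grad=True)
old_id = id(x)
y = x ** 2                     # intermediate node with its own grad_fn
z = y ** 3
print(y.grad_fn, z.grad_fn)    # <PowBackward0 ...> <PowBackward0 ...>
x = x ** 2                     # re-using the variable: a new tensor, with a new id
print(id(x) == old_id)         # False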

EDIT for the first comment: from one of your notes I noticed I uploaded the wrong image for the second WhatIF; anyway, I got your point.

Question: is there some function that plots the computational graphs?

lucmos commented 3 years ago

Yes, seems right to me!

For visualizing computational graphs, I am using torchviz. Usage example:

!pip install torchviz
import torch
from torchviz import make_dot
x  = torch.tensor(42., requires_grad=True)
f = x + x
make_dot(f)
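
If you are not in a notebook, the object returned by make_dot is a graphviz Digraph, so (assuming the graphviz binary is installed) you can also render the graph above to a file:

make_dot(f).render("computational_graph", format="png")  # writes computational_graph.png (plus the DOT source)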

To expand a bit on point 3., this page (from Head First Java, but apart from some minor differences it also applies here) gives a good analogy:

[Image: page from Head First Java on object references and the cup analogy]

When you re-assign the x variable:

x = something

You are just changing the reference inside the cup. You do not touch the "Dog object" at all, i.e. your tensor.