sh3rlock14 opened this issue 3 years ago
Hello, thanks for your insights!
I'd just like to make sure we all are on the same page.
The initial snippet is the following:
x = torch.tensor(42., requires_grad=True)
x2 = x ** 2
x3 = x ** 3
f = x2 + x3
f.backward(retain_graph=False)
x.grad
and it yields the following computational graph:
When you write:
x = torch.tensor(42., requires_grad=True)
x = x ** 2  # re-use the same variable
x3 = x ** 3
f = x + x3
x.retain_grad()
f.backward(retain_graph=False)
x.grad (=9335089)
The computational graph changes:
But this change is not due to the re-assignment of the variable x. You are literally re-using the same variable, not the underlying object it refers to -- a tensor in this case. You did not modify the original tensor torch.tensor(42., requires_grad=True); you just re-assigned the variable x to another tensor (x ** 2) and thus lost almost all references to that original object -- to be precise, there are still some references to it inside the computational graph. The lines x3 = x ** 3 and f = x + x3 are responsible for the change, since the tensor that x now refers to has changed.
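To see this concretely, here is a minimal sketch (the extra name original is only there for illustration): the re-assignment only rebinds the name x, while the original tensor object is untouched.

import torch

x = torch.tensor(42., requires_grad=True)
original = x          # keep a second name for the same tensor object
x = x ** 2            # re-assignment: the name 'x' now points to a new tensor
print(x is original)  # False -> only the name moved
print(original)       # tensor(42., requires_grad=True), unchanged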
For example, you can play a bit with temp variables:
x = temp0 = torch.tensor(42., requires_grad=True)
x = temp1 = temp0 ** 2
x3 = temp0 ** 3
f = temp1 + x3
f.backward(retain_graph=False)
to get the same initial computational graph, even though x has been re-assigned:
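For instance (a quick check, assuming the snippet above is run as-is), reading the gradient accumulated on temp0 gives the same value as the initial snippet:

import torch

x = temp0 = torch.tensor(42., requires_grad=True)
x = temp1 = temp0 ** 2
x3 = temp0 ** 3
f = temp1 + x3
f.backward(retain_graph=False)
print(temp0.grad)  # tensor(5376.) = 2*42 + 3*42**2, same as the initial snippet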
Just for completeness, this is the computational graph for your second example:
x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3  # re-use the same variable
f = y + x
x.retain_grad()
f.backward(retain_graph=False)
It is the same as for your first example, not only because the operations happen to be the same (which can occur by coincidence), but because the two programs are equivalent:
Example1:
x = torch.tensor(42., requires_grad=True)
x = x ** 2
x3 = x ** 3
f = x + x3
Can be simplified to: f = (x ** 2) + ((x ** 2) ** 3)
Example2:
x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3 #re-use the same variable
f = y + x
Can be simplified to: f = (x ** 2) + ((x ** 2) ** 3)
(To perform this simplification, proceed bottom-up and check what each variable refers to at every step.)
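To double-check the equivalence empirically, here is a small sketch (the names f1, f2, g1, g2 are just for illustration): both programs produce the same forward value and the same gradient on the x ** 2 node.

import torch

# Example 1
x = torch.tensor(42., requires_grad=True)
x = x ** 2
x3 = x ** 3
f1 = x + x3
x.retain_grad()
f1.backward()
g1 = x.grad  # gradient on the (x ** 2) node

# Example 2
x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3
f2 = y + x
y.retain_grad()
f2.backward()
g2 = y.grad  # same node, just named y here

print(torch.equal(f1, f2))  # True: same forward value
print(torch.equal(g1, g2))  # True: both tensor(9335089.)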
Thus I disagree with this sentence:
reusing the "base tensor" modifies the computational graph making it shorter, and possibly truncating some computational subgraph as in here.
You are not modifying the original computational graph with that assignment; you are just re-using the name x to refer to a different tensor. (If the subsequent code uses the same variable, the computational graph will of course change, since the semantics of the code has changed.)
tl;dr: variables in Python are just references to objects (this holds even for base types like int); re-assigning a variable means making it point to another object. It does not change the object the variable used to point to (e.g. a tensor in this case).
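The same point holds for any Python object, tensors aside (a tiny illustrative sketch):

a = [1, 2, 3]
b = a            # 'b' and 'a' refer to the same list object
a = [4, 5, 6]    # re-assignment rebinds the name 'a' to a new object
print(b)         # [1, 2, 3] -- the original object was not modified
print(a is b)    # False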
I am not sure if that was what you meant, but just to be sure and to clarify a bit for everyone.
Let me know if I missed something or there are other doubts!
Thanks for the reply!
I'll try to summarize what I got so far:
1. x is the tensor onto which grads accumulate
2. y, w, ..., z are intermediate nodes, each one having its own grad_fn
3. when re-assigning x, I'm allocating in memory (with a new id) a space storing a new tensor
4. f will be defined on the latest assignments of its own components (tensors)

EDIT for the first comment: from one of your notes I noticed I uploaded the wrong image for the second WhatIF; anyway, I got your point.
Question: is there some function that plots the computational graphs?
Yes, seems right to me!
About the computational graph visualization, I am using torchviz. Usage example:
!pip install torchviz
import torch
from torchviz import make_dot
x = torch.tensor(42., requires_grad=True)
f = x + x
make_dot(f)
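In case it is useful, make_dot returns a graphviz Digraph, so (assuming the graphviz binaries are installed) you can also label the leaf tensor and save the plot to a file; a small sketch:

import torch
from torchviz import make_dot

x = torch.tensor(42., requires_grad=True)
x2 = x ** 2
x3 = x ** 3
f = x2 + x3

# name the leaf tensor in the plot and render the graph to a PNG file
dot = make_dot(f, params={"x": x})
dot.render("computational_graph", format="png")  # writes computational_graph.png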
To expand a bit on point 3, this page (from Head First Java, but it applies here as well apart from some minor changes) gives a good analogy:
- the x variable is the cup (the "Dog cup" in the picture)
- what x holds is a reference (the remote controller)

When you re-assign the x variable:

x = something

you are just changing the reference inside the cup. You did not touch the "Dog object" at all, i.e. your tensor.
I would have asked for some clarification on what happens "under the hood" during the backward pass, but I think I figured it out, and I would like to share some thoughts with you to get feedback, hoping this can also dispel some doubts.
This is the code you gave us in the notebook to compute the rate of change at x=42
x = torch.tensor(42., requires_grad=True)
x2 = x ** 2
x3 = x ** 3
f = x2 + x3
f.backward(retain_graph=False)
x.grad (=5376)
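Just to tie the number to the math, a quick hand check (plain Python, no autograd) gives the same value:

# d(x**2 + x**3)/dx = 2*x + 3*x**2, evaluated at x = 42
x = 42
print(2 * x + 3 * x ** 2)  # 5376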
I played a bit with the re-assignment of the tensor we differentiate with respect to, and at the beginning I could not figure out the results I was obtaining; then I sketched some graphs, derived every node step by step, and here's what I found:
WhatIF (1):
x = torch.tensor(42., requires_grad=True)
x = x ** 2  # re-use the same variable
x3 = x ** 3
f = x + x3
x.retain_grad()
f.backward(retain_graph=False)
x.grad (=9335089)

and
WhatIF (2):
x = torch.tensor(42., requires_grad=True)
y = x ** 2
x = y ** 3  # re-use the same variable
f = y + x
x.retain_grad()
f.backward(retain_graph=False)
x.grad (=1)
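For reference, both numbers can be checked by hand, keeping in mind which tensor the name x refers to when retain_grad() is called (a small sketch):

# WhatIF (1): after x = x ** 2 the name x refers to u = 42 ** 2 = 1764,
# and f = u + u ** 3, so df/du = 1 + 3 * u ** 2
u = 42 ** 2
print(1 + 3 * u ** 2)  # 9335089

# WhatIF (2): after x = y ** 3 the name x refers to the last node of f = y + x,
# so df/dx is simply 1, which is why x.grad is 1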
And what I got so far is that reusing the "base tensor" modifies the computational graph making it shorter, and possibly truncating some computational subgraph as in here:
dy/dx won't affect df/dx, and previous assignments to the variable x are ignored. Also, PyTorch keeps deriving and substituting for the original x only if we defined new variables while building f: in this case there won't be any z variable in the graph, so we won't substitute for the original x and the derivation returns 1.