dvgodoy / PyTorchStepByStep

Official repository of my book: "Deep Learning with PyTorch Step-by-Step: A Beginner's Guide"
https://pytorchstepbystep.com
MIT License

Extended the scaled dot product example code #54

Open jdgh000 opened 1 week ago

jdgh000 commented 1 week ago

So I am getting into attention networks, one of the toughest topics to understand, and the book explains it great so far. However, in the scaled dot product example, the book shows how multiplying the product of K and Q by 100 skews the scores. I extended this example by using the actual keys and query from the earlier example (on p262, in order to compute dims, which happens to be just 2) and compared the non-scaled (p275) and scaled versions side by side. But even in the scaled one, there still seems to be a big variance between prod and 100*prod. I was hoping to see similar results whether or not prod is multiplied by 100, or I must be doing something wrong...:


# A small example showing the difference it makes when the dot product is not scaled,
# and the resulting skew in the attention scores.
import sys
import numpy as np

import torch
import torch.nn.functional as F

sys.path.append('..')

# The Encoder and Decoder classes from the book's chapter 9 are assumed
# to be importable from these local helper modules
from common.settings import *
from common.classes import *

# Full sequence of corners, split into source (first two) and target (last two) points
full_seq = torch.tensor([[-1, -1], [1, -1], [1, 1], [1, -1]]).float().view(1, 4, 2)
source_seq = full_seq[:, :2]
target_seq = full_seq[:, 2:]

# Encode the source sequence; the hidden states double as keys and values (p262)
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
hidden_seq = encoder(source_seq)
values = hidden_seq
keys = hidden_seq

# Run the decoder for one step and use its hidden state as the query
torch.manual_seed(21)
decoder = Decoder(n_features=2, hidden_dim=2)
decoder.init_hidden(hidden_seq)
inputs = source_seq[:, -1:]
out = decoder(inputs)
query = decoder.hidden.permute(1, 0, 2)

# Dot products between the query and the keys
products = torch.bmm(query, keys.permute(0, 2, 1))

print("Non scaled")
print(F.softmax(products, dim=-1))
print(F.softmax(100 * products, dim=-1))

print("Scaled")
dims = query.size(-1)
scaled_products = products / np.sqrt(dims)  # scale by sqrt(d_k), here sqrt(2)
print(F.softmax(scaled_products, dim=-1))
print(F.softmax(100 * scaled_products, dim=-1))

Result:

python3 ch9-p275-softmax-skew-due-to-non-scaled-dot-modded.py
printFnc: func:  <function namestr at 0x7fd637ed8c10>
printFnc: func:  <function Linear.__init__ at 0x7fd637eed0d0>
Encoder.init(n_features = 2 , hidden_dim = 2 )
Encoder.forward(X= torch.Size([1, 2, 2]) )
Non scaled
tensor([[[0.3295, 0.6705]]], grad_fn=<SoftmaxBackward0>)
tensor([[[1.3756e-31, 1.0000e+00]]], grad_fn=<SoftmaxBackward0>)
Scaled
tensor([[[0.3770, 0.6230]]], grad_fn=<SoftmaxBackward0>)
tensor([[[1.5053e-22, 1.0000e+00]]], grad_fn=<SoftmaxBackward0>)
dvgodoy commented 1 week ago

Hi @jdgh000,

Your code seems to be perfectly right. I am guessing the skewed values you're concerned about are just the effect of the softmax function. Softmax does skew values a lot when you increase the scale of the inputs, even if the proportion between the two input values is the same.

If we take two values, say, 0.01 and 0.1, the second is 10x larger than the first, but softmax will return fairly similar results for both:

print(F.softmax(torch.as_tensor([.01, .1]), dim=-1))

tensor([0.4775, 0.5225])

However, if we multiply these values by 10, their proportion remains unchanged, but their overall level is 10x higher, thus affecting how softmax transforms them:

print(F.softmax(torch.as_tensor([.01, .1])*10, dim=-1))

tensor([0.2891, 0.7109])

If we try values 100x larger than the initial ones, we start seeing the kind of extremely skewed values you mentioned:

print(F.softmax(torch.as_tensor([.01, .1])*100, dim=-1))

tensor([1.2339e-04, 9.9988e-01])
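
To make the role of the exponentiation explicit: with two inputs, the ratio between the two softmax outputs is exp(x2 - x1), so scaling the inputs by a factor c raises that ratio to the power c. Here is a minimal sketch (my own check, using the same 0.01 and 0.1 values as above, not code from the book) that verifies this:

import torch
import torch.nn.functional as F

x = torch.as_tensor([.01, .1])

for c in [1, 10, 100]:
    probs = F.softmax(c * x, dim=-1)
    # the ratio of the two probabilities equals exp(c * (x[1] - x[0])),
    # so it grows exponentially with the scaling factor c
    ratio = torch.exp(c * (x[1] - x[0]))
    print(c, probs, probs[1] / probs[0], ratio)
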

So, it all boils down to the fact that the softmax function exponentiates the inputs in order to transform them into probabilities that add up to one. Does this answer your question?

Best, Daniel

jdgh000 commented 1 week ago

Yes, I think so. It might be interesting to pursue this path further, but I'd move on. I just wanted to check that my understanding is correct through some sample code. BTW, this appears to be a simple explanation of softmax: https://victorzhou.com/blog/softmax/