facebookresearch / dlrm

An implementation of a deep learning recommendation model (DLRM)
MIT License

Loss is way too high when applying QR embedding with the 'add' operation #326

Closed. YoungsukKim12 closed this issue 1 year ago.

YoungsukKim12 commented 1 year ago

Hello, I'm interested in training DLRM with QR embedding, but I'm having some trouble reproducing the results reported in the paper.

When I use the 'mult' operation, I get a loss similar to the paper's. But with the 'add' operation, the loss rises to nearly 75 and never comes back down to the normal range (around 0.5). I want to use the 'add' operation for my research, but I can't figure out what is causing this. Is there any way to solve this problem?

hjmshi commented 1 year ago

Hi Gray, thanks for your interest in the QR work! Could you confirm which initialization you are using for the embedding bags within QR for each operation, and which optimizer you are using?

As a heuristic, it is best to initialize the Q and R embedding bags so that their sum or product yields something similar or equal to the standard initialization of a regular embedding bag. We also used the Adagrad and AMSGrad optimizers in our experiments.
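To make the heuristic concrete, here is a small stdlib-only sketch. It assumes the standard embedding bag is initialized as U(-b, b) with b = sqrt(1/n), and that Q and R entries are drawn independently from a symmetric U(-a, a). It solves for the half-width a that makes the per-element variance of q + r (or q * r) match the standard one, then checks the result by sampling. The function names are hypothetical and not part of the DLRM code.

```python
import math
import random


def standard_bound(num_categories):
    # Assumed standard embedding init: U(-sqrt(1/n), sqrt(1/n)).
    return math.sqrt(1 / num_categories)


def matched_bound(num_categories, operation):
    """Half-width a such that q, r ~ U(-a, a), combined by `operation`,
    have the same per-element variance as the standard init."""
    b = standard_bound(num_categories)
    if operation == "add":
        # Var(q + r) = 2 * a^2 / 3 must equal b^2 / 3, so a = b / sqrt(2).
        return b / math.sqrt(2)
    if operation == "mult":
        # Var(q * r) = (a^2 / 3)^2 for zero-mean independent q, r,
        # which must equal b^2 / 3, so a = (3 * b^2) ** 0.25.
        return (3 * b * b) ** 0.25
    raise ValueError(f"unsupported operation: {operation}")


def empirical_variance(operation, a, samples=100_000, seed=0):
    # Monte Carlo estimate of the variance of the combined entry.
    rng = random.Random(seed)
    values = []
    for _ in range(samples):
        q, r = rng.uniform(-a, a), rng.uniform(-a, a)
        values.append(q + r if operation == "add" else q * r)
    mean = sum(values) / samples
    return sum((v - mean) ** 2 for v in values) / samples


if __name__ == "__main__":
    n = 1000
    target = standard_bound(n) ** 2 / 3
    for op in ("add", "mult"):
        a = matched_bound(n, op)
        # The ratio should be close to 1 if the variances match.
        print(op, a, empirical_variance(op, a) / target)
```

Under these assumptions, halving the bound for 'add' (a = b/2) would give Var(q + r) = b^2/6, which is half the target variance; dividing by sqrt(2) would match it.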

YoungsukKim12 commented 1 year ago

Hello hjmshi, thanks for the reply.

I used Adagrad as the optimizer. To address the Q and R initialization issue when using summation, I modified the reset_parameters() function in tricks/qr_embedding_bag.py as follows:

Original :

def reset_parameters(self):
    nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories))
    nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories))

Changed :

def reset_parameters(self):
    nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories) / 2)
    nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories) / 2)

Is there anything else that I should fix?

hjmshi commented 1 year ago

Hi Gray, that should be sufficient if you're using the 'add' operator. Can you let me know if it works?

Note that the previous reset_parameters makes sense if you are using the 'concat' operator.

YoungsukKim12 commented 1 year ago

Hello hjmshi, sorry for the late reply.

I tried it, but the loss still doesn't decrease. If I divide the sum of the Q and R vectors (in the forward() function in tricks/qr_embedding_bag.py), the loss goes down, but still not as far as the results reported in the paper.

Original :

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r
        elif self.operation == 'mult':
            embed = embed_q * embed_r

Changed:

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r / 8
        elif self.operation == 'mult':
            embed = embed_q * embed_r

hjmshi commented 1 year ago

Hi @GrayGlacier, sorry, somehow this response got lost in the shuffle...

Let us focus on the add case for now. If you are going to change the code, you can change it as follows:

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = (embed_q + embed_r) / 2
        elif self.operation == 'mult':
            embed = embed_q * embed_r
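Operator precedence matters in these variants: in Python, `/` binds tighter than `+`, so `embed_q + embed_r / 8` scales only `embed_r`, whereas `(embed_q + embed_r) / 2` scales the whole sum. A small sketch of the three combine operations follows; `combine` and its `scale` parameter are hypothetical names added for illustration and are not part of `tricks/qr_embedding_bag.py`.

```python
import torch


def combine(embed_q, embed_r, operation, scale=2.0):
    # Hypothetical helper mirroring the forward() branches quoted above.
    if operation == "concat":
        return torch.cat((embed_q, embed_r), dim=1)
    if operation == "add":
        # Parentheses make the division apply to the whole sum.
        return (embed_q + embed_r) / scale
    if operation == "mult":
        return embed_q * embed_r
    raise ValueError(f"unsupported operation: {operation}")


embed_q = torch.tensor([[1.0, 2.0]])
embed_r = torch.tensor([[3.0, 4.0]])
print(combine(embed_q, embed_r, "add"))  # values [[2.0, 3.0]]
print(embed_q + embed_r / 8)             # only embed_r is scaled: [[1.375, 2.5]]
print((embed_q + embed_r) / 8)           # the sum is scaled: [[0.5, 0.75]]
```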

For the initialization, using the original reset_parameters function is fine:

def reset_parameters(self):
    nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories))
    nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories))
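One detail that may be worth checking in this call: `torch.nn.init.uniform_` is documented as `uniform_(tensor, a=0.0, b=1.0)`, so a single positional bound sets only the lower limit `a`, and the upper limit stays at 1.0. The sketch below compares that call with an explicitly symmetric one; the tensor sizes are arbitrary, and whether this affects the loss reported here was not established in the thread.

```python
import math

import torch
import torch.nn as nn

num_categories = 1000  # arbitrary size chosen for illustration
bound = math.sqrt(1 / num_categories)

w_one_arg = torch.empty(num_categories, 16)
w_symmetric = torch.empty(num_categories, 16)

# A single positional bound binds to `a`; `b` keeps its default of 1.0,
# so every sample lies in [bound, 1.0) and all entries are positive.
nn.init.uniform_(w_one_arg, bound)

# Passing both bounds gives the symmetric interval [-bound, bound).
nn.init.uniform_(w_symmetric, -bound, bound)

print(w_one_arg.min().item(), w_one_arg.max().item())
print(w_symmetric.min().item(), w_symmetric.max().item())
```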

Have you tuned Adagrad as well? I recall trying several different learning rates. If I remember correctly, we actually used AMSGrad for the QR embeddings in our experiments, and it worked better than Adagrad. The discrepancy may also come from the choice of optimizer and its hyperparameters.
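A minimal sketch of such a learning-rate sweep, assuming a hypothetical toy model (an `nn.EmbeddingBag` with `sparse=True` feeding one linear layer) and random synthetic data. The learning rates are placeholders, not the values used in the paper. `torch.optim.Adagrad` accepts the sparse embedding gradients and the dense layer gradients in the same optimizer, which is also how `dlrm_s_pytorch.py` passes all parameters to one optimizer.

```python
import torch
import torch.nn as nn


def sweep_adagrad(learning_rates, steps=50, seed=0):
    """Return the loss at the last step for each learning rate on a toy problem."""
    results = {}
    for lr in learning_rates:
        torch.manual_seed(seed)  # same data and init for every learning rate
        emb = nn.EmbeddingBag(100, 8, mode="sum", sparse=True)  # sparse grads
        head = nn.Linear(8, 1)  # dense grads
        # Synthetic data: 32 bags of 4 indices each, random regression targets.
        idx = torch.randint(0, 100, (32, 4))
        target = torch.rand(32, 1)
        params = list(emb.parameters()) + list(head.parameters())
        opt = torch.optim.Adagrad(params, lr=lr)
        for _ in range(steps):
            opt.zero_grad()
            loss = nn.functional.mse_loss(head(emb(idx)), target)
            loss.backward()
            opt.step()
        results[lr] = loss.item()
    return results


print(sweep_adagrad([0.001, 0.01, 0.1]))
```

On a real DLRM run the same loop would wrap full training and compare validation loss rather than the toy training loss.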

YoungsukKim12 commented 1 year ago

Hello @hjmshi, thanks for your continuous support. Tuning Adagrad definitely helped: the model converged and the training loss reached about 0.5. Thanks! However, I want to make the loss smaller, so I tried AMSGrad as you recommended.

I tried using AMSGrad by enabling the amsgrad option in torch.optim.Adam. To do that, I changed the code inside the run() function in dlrm_s_pytorch.py as follows:

        if args.optimizer == 'Adam':
            optimizer = torch.optim.Adam(parameters, lr=args.learning_rate, amsgrad=True)
        else:
            optimizer = opts[args.optimizer](parameters, lr=args.learning_rate)

But I got the following error:

  File "/home/youngsuk95/.conda/envs/yskim/lib/python3.9/site-packages/torch/optim/adam.py", line 107, in step
    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
RuntimeError: Adam does not support sparse gradients, please consider SparseAdam instead

So I switched to the SparseAdam optimizer, but it fails with a similar error:

  File "/home/youngsuk95/.conda/envs/yskim/lib/python3.9/site-packages/torch/optim/sparse_adam.py", line 80, in step
    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
RuntimeError: SparseAdam does not support dense gradients, please consider Adam instead

Is there any way I can use AMSGrad or Adam?
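Both errors come from mixing gradient types: with sparse embeddings, the embedding tables produce sparse gradients while the MLP layers produce dense ones. A common workaround, sketched below on a hypothetical toy model, is to give each group its own optimizer: `torch.optim.SparseAdam` for the sparse parameters and `torch.optim.Adam(..., amsgrad=True)` for the dense ones. Note that `SparseAdam` has no `amsgrad` option, so under this split the embedding tables still get plain (lazy) Adam rather than AMSGrad.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical stand-in: one sparse embedding table plus a dense layer.
    def __init__(self, num_embeddings=100, dim=8):
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings, dim, mode="sum", sparse=True)
        self.mlp = nn.Linear(dim, 1)

    def forward(self, idx):
        return self.mlp(self.emb(idx))


def make_optimizers(model, lr):
    # SparseAdam only accepts sparse gradients; Adam only accepts dense ones.
    sparse_opt = torch.optim.SparseAdam(list(model.emb.parameters()), lr=lr)
    dense_opt = torch.optim.Adam(model.mlp.parameters(), lr=lr, amsgrad=True)
    return sparse_opt, dense_opt


def train_step(model, optimizers, idx, target):
    for opt in optimizers:
        opt.zero_grad()
    loss = nn.functional.mse_loss(model(idx), target)
    loss.backward()
    for opt in optimizers:
        opt.step()
    return loss.item()


torch.manual_seed(0)
model = ToyModel()
optimizers = make_optimizers(model, lr=0.01)
idx = torch.randint(0, 100, (32, 4))  # synthetic bags of indices
target = torch.rand(32, 1)            # synthetic targets
losses = [train_step(model, optimizers, idx, target) for _ in range(20)]
print(losses[0], losses[-1])
```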

hjmshi commented 1 year ago

Hi @GrayGlacier, unfortunately, the default PyTorch implementation of AMSGrad provided in Adam does not support sparse parameters. During my internship, I had actually implemented my own version of AMSGrad that is compatible with sparse parameters for experimentation purposes. Is it necessary to use AMSGrad in order to get what you need, or is a 0.5 loss good enough?
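The maintainer's own implementation is not shown in this thread. As a rough illustration only, here is one way a sparse-compatible AMSGrad step could look, assuming a "lazy" update (only rows present in the current sparse gradient are updated, as `torch.optim.SparseAdam` does) combined with the AMSGrad running maximum of the second moment. `LazySparseAMSGrad` is a hypothetical name, and the sketch has not been validated on DLRM.

```python
import torch


class LazySparseAMSGrad(torch.optim.Optimizer):
    """Hypothetical sketch: AMSGrad restricted to rows present in a sparse gradient."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                if not p.grad.is_sparse:
                    raise RuntimeError("this sketch only handles sparse gradients")
                grad = p.grad.coalesce()  # merge duplicate row indices
                rows = grad.indices()[0]
                values = grad.values()
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                    state["max_exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                t = state["step"]
                # Moment updates only for the touched rows ("lazy" update).
                m = beta1 * state["exp_avg"][rows] + (1 - beta1) * values
                v = beta2 * state["exp_avg_sq"][rows] + (1 - beta2) * values * values
                # AMSGrad: keep the running maximum of the second moment.
                v_max = torch.maximum(state["max_exp_avg_sq"][rows], v)
                state["exp_avg"][rows] = m
                state["exp_avg_sq"][rows] = v
                state["max_exp_avg_sq"][rows] = v_max
                m_hat = m / (1 - beta1 ** t)
                denom = (v_max / (1 - beta2 ** t)).sqrt() + group["eps"]
                p[rows] = p[rows] - group["lr"] * m_hat / denom
        return loss
```

After a single step this matches `torch.optim.Adam(..., amsgrad=True)` applied densely to the same gradient; on later steps the two diverge, because rows absent from a batch keep their stale moments here instead of being moved by decaying momentum.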

YoungsukKim12 commented 1 year ago

Hello @hjmshi, I think 0.5 is good enough if AMSGrad would have to be implemented from scratch. Thanks for your kind answers!

hjmshi commented 1 year ago

Sounds good, don't hesitate to reach out if you have any other questions. :)