jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.96k stars 2.75k forks source link

When I compute gradients and perform updates using the same values in Torch and JAX, I find that after multiple iterations, the inference results differ significantly. #23646

Open CZXIANGOvO opened 4 days ago

CZXIANGOvO commented 4 days ago

Description

Please set the device to `cuda:0` at the very beginning of the script.

import torch
import numpy as np
import os
from network.cv.SSD.backbone_mobilenetv1_pytorch import SSDWithMobileNetV1 as SSD_torch
import jax
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax

if "CONTEXT_DEVICE_TARGET" in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
    devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
    device = devices[-2]
    final_device = "cuda:" + device
else:
    final_device = 'cpu'

def class_loss_jax(logits, label):
    """Calculate the sigmoid focal classification loss.

    A sigmoid binary cross-entropy is reduced to its mean over all elements.
    That scalar is then scaled element-wise by the focal modulating factor
    ``(1 - p_t) ** 2`` and by the alpha weight (0.75 where the one-hot label
    is 1, 0.25 elsewhere).

    Args:
        logits: Raw class scores; the last axis indexes classes.
        label: Integer class ids used to index ``jnp.eye(num_classes)``;
            presumably shaped like ``logits.shape[:-1]`` (TODO confirm against
            the caller).

    Returns:
        Array with the same shape as ``logits`` holding the per-element loss.
    """
    num_classes = logits.shape[-1]
    label = jnp.eye(num_classes)[label]
    weight = jnp.ones_like(logits)
    pos_weight = jnp.ones_like(logits)

    # Binary cross entropy in log-space: log(sigmoid(x)) == log_sigmoid(x) and
    # log(1 - sigmoid(x)) == log_sigmoid(-x). The previous form
    # log(sigmoid(x) + 1e-15) clamps each term near -log(1e-15) ~= 34.5.
    # That happens once 1 - sigmoid(x) rounds to 0 in float32 (x above ~17)
    # or sigmoid(x) falls below the epsilon (x below ~-34.5). The log-space
    # form keeps the true loss and gradient for large logits, as a
    # logits-based BCE does.
    term1 = label * jax.nn.log_sigmoid(logits)
    term2 = (1 - label) * jax.nn.log_sigmoid(-logits)
    loss = -(weight * (term1 * pos_weight + term2))
    # Reduced to a scalar mean *before* the focal weighting, as in the
    # original implementation.
    sigmoid_cross_entropy = jnp.mean(loss)

    probs = sc.expit(logits)
    p_t = label * probs + (1 - label) * (1 - probs)
    modulating_factor = jnp.power(1 - p_t, 2.0)
    alpha_weight_factor = label * 0.75 + (1 - label) * (1 - 0.75)
    return modulating_factor * alpha_weight_factor * sigmoid_cross_entropy

def SSDmultibox_jax_cal(params_, pred_loc, pred_label, gt_loc, gt_label, num_matched_boxes):
    """Compute the SSD multibox loss: masked smooth-L1 localisation + focal classification.

    JAX counterpart of ``loss_SSDmultibox_torch.forward``.

    Args:
        params_: Accepted but never read. Because the loss does not depend on
            it, ``jax.value_and_grad`` with respect to this argument returns
            all-zero gradients; an optimizer fed those gradients never moves
            the parameters.
        pred_loc: Predicted box offsets; the location mask is tiled 4 times on
            a new last axis, so a last axis of size 4 is presumed.
        pred_label: Class logits, forwarded to ``class_loss_jax``.
        gt_loc: Target box offsets, same shape as ``pred_loc``.
        gt_label: Class ids; entries greater than 0 mark positive boxes.
        num_matched_boxes: Summed as float32 to normalise the loss.

    Returns:
        Scalar loss.
    """
    total_matched = jnp.sum(num_matched_boxes.astype(jnp.float32))
    positive = jnp.greater(gt_label, 0).astype(jnp.float32)

    # Smooth-L1 (beta = 1) on the box offsets, counted only for positive boxes.
    loc_mask = jnp.tile(positive[..., jnp.newaxis], (1, 1, 4))
    abs_err = jnp.abs(pred_loc - gt_loc)
    quadratic = 0.5 * abs_err ** 2
    linear = abs_err - 0.5
    masked_l1 = jnp.where(abs_err < 1, quadratic, linear) * loc_mask
    loss_loc = masked_l1.sum(axis=-1).sum(axis=-1)

    loss_cls = class_loss_jax(pred_label, gt_label).sum(axis=(1, 2))

    return jnp.sum((loss_cls + loss_loc) / total_matched)

class loss_SSDmultibox_torch(torch.nn.Module):
    """SSD multibox loss in PyTorch.

    Sums a masked smooth-L1 localisation loss and the project's
    ``class_loss`` classification loss, normalised by the number of matched
    boxes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_loc, pred_label, gt_loc, gt_label, num_matched_boxes):
        """Return the scalar multibox loss.

        Args:
            pred_loc: Predicted box offsets; the mask is repeated 4 times on
                the last axis, so a last axis of size 4 is presumed.
            pred_label: Class logits passed to ``class_loss``.
            gt_loc: Target box offsets, same shape as ``pred_loc``.
            gt_label: Class ids; entries greater than 0 mark positive boxes.
            num_matched_boxes: Summed as float to normalise the loss.
        """
        # Project helper, imported at call time as in the original code.
        from network.cv.SSD.ssd_utils_torch import class_loss

        total_matched = num_matched_boxes.float().sum()
        positive = (gt_label > 0).float()

        # Localisation loss: smooth-L1 (beta = 1) restricted to positive boxes.
        loc_mask = positive[..., None].repeat(1, 1, 4)
        per_coord = torch.nn.functional.smooth_l1_loss(pred_loc, gt_loc, reduction='none')
        loss_loc = (per_coord * loc_mask).sum(dim=-1).sum(dim=-1)

        # Classification loss.
        loss_cls = class_loss(pred_label, gt_label).sum(dim=(1, 2))

        return ((loss_cls + loss_loc) / total_matched).sum()

def _load_tensor(path):
    """Load a ``.npy`` file from ``path`` and move it to ``final_device`` as a torch tensor."""
    return torch.from_numpy(np.load(path)).to(final_device)


# Fixed inputs and targets, read from .npy files in the working directory.
image_torch = _load_tensor('./image_torch.npy')

# pred_loc_torch / pred_label_torch are rebound by the first forward pass of
# the training loop below.
pred_loc_torch = _load_tensor('./pred_loc_torch.npy')
pred_label_torch = _load_tensor('./pred_label_torch.npy')
box_torch = _load_tensor('./box_torch.npy')
label_torch = _load_tensor('./label_torch.npy')
num_match_torch = _load_tensor('./num_match_torch.npy')

model_torch = SSD_torch()
model_torch.train()
model_torch.to(final_device)

# Both sides use plain SGD with the same step size: torch.optim.SGD without
# momentum and optax.sgd without momentum both apply p <- p - lr * grad.
learning_rate = 0.02
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=learning_rate)

# Start JAX from the same weights: copy the torch state into float32 arrays.
# NOTE(review): state_dict() holds buffers as well as parameters() (e.g.
# BatchNorm running statistics, if the model has any). Those entries are cast
# to float32 here and handed to optax as if they were trainable, whereas the
# torch optimizer only sees parameters().
params_torch = {k: t.detach().cpu().numpy() for k, t in model_torch.state_dict().items()}
params_jax = {k: jnp.array(arr, dtype=jnp.float32) for k, arr in params_torch.items()}

loss_fun_torch = loss_SSDmultibox_torch()
optimizer_jax = optax.sgd(learning_rate)
opt_state = optimizer_jax.init(params_jax)

# Loop-invariant values: the targets never change, so convert them once.
box_jax = box_torch.detach().cpu().numpy()
label_jax = label_torch.detach().cpu().numpy()
num_match_jax = num_match_torch.detach().cpu().numpy()
loss_fun_jax = SSDmultibox_jax_cal
# NOTE: this is why the two runs diverge. SSDmultibox_jax_cal never reads its
# first argument (params_). The predictions it receives are NumPy copies of a
# PyTorch forward pass, so JAX cannot differentiate through the network.
# Consequences:
# - value_and_grad with respect to params_jax returns all-zero gradients;
# - optax.sgd therefore produces all-zero updates and params_jax stays at the
#   initial weights;
# - loss_jax keeps measuring the initial model, while loss_torch measures the
#   model PyTorch is actually training.
# A meaningful comparison needs the SSD forward pass written in JAX and called
# inside the differentiated loss function.
jax_loss_and_grad = jax.value_and_grad(loss_fun_jax)

for _ in range(10):
    # --- PyTorch training step ---
    pred_loc_torch, pred_label_torch = model_torch(image_torch)
    loss_torch = loss_fun_torch(pred_loc_torch, pred_label_torch, box_torch, label_torch, num_match_torch)
    loss_torch.backward()
    optimizer_torch.step()
    optimizer_torch.zero_grad()

    # Save the updated torch weights; they are restored after the JAX step.
    torch.save(model_torch.state_dict(), './model_weights.pth')

    # --- JAX step: run the torch network with the JAX-side parameters ---
    # np.array copies each JAX array into a fresh NumPy array before handing it to torch.
    params_torch_updated = {name: torch.from_numpy(np.array(value)) for name, value in params_jax.items()}
    model_torch.load_state_dict(params_torch_updated)
    pred_loc_torch, pred_label_torch = model_torch(image_torch)
    pred_loc_jax = pred_loc_torch.detach().cpu().numpy()
    pred_label_jax = pred_label_torch.detach().cpu().numpy()

    loss_jax, jax_grads = jax_loss_and_grad(params_jax, pred_loc_jax, pred_label_jax, box_jax,
                                            label_jax, num_match_jax)
    updates, opt_state = optimizer_jax.update(jax_grads, opt_state, params_jax)
    params_jax = optax.apply_updates(params_jax, updates)

    # Surface the zero-gradient problem described above instead of leaving it silent.
    max_abs_jax_grad = max(
        (float(jnp.max(jnp.abs(g))) for g in jax.tree_util.tree_leaves(jax_grads)),
        default=0.0,
    )
    print('max |jax_grad|:', max_abs_jax_grad)

    # Put the PyTorch weights back before the next PyTorch step.
    model_torch.load_state_dict(torch.load('./model_weights.pth'))

    print('loss_jax/loss_torch:', np.array(loss_jax) / loss_torch.cpu().detach().numpy())

[Screenshot 2024-09-15 191310]

System info (python version, jaxlib version, accelerator, etc.)

Download the code: https://drive.google.com/file/d/1H8uPgPdslVpizmSsif6oK4ey2e-oum9x/view?usp=sharing

!unzip issue3.zip
python issue1.py
jakevdp commented 4 days ago

If it's truly the same model you're running, this is surprising. Given the magnitude of the difference, though, I suspect the models or optimizers differ in important ways: for example, maybe the precise definition of "learning rate" differs between the two implementations.

I don't know either optax or pytorch well enough to guess where that difference might lie, but if it's important to you to debug these differences in the implementations, that's probably where I'd start.

CZXIANGOvO commented 1 day ago

> If it's truly the same model you're running, this is surprising. Given the magnitude of the difference, though, I suspect the models or optimizers differ in important ways: for example, maybe the precise definition of "learning rate" differs between the two implementations.
>
> I don't know either optax or pytorch well enough to guess where that difference might lie, but if it's important to you to debug these differences in the implementations, that's probably where I'd start.

We're using the same model.

jakevdp commented 1 day ago

> We're using the same model.

Sure, but what I'm suggesting is that you may not be using the same optimizer.