luyug / GradCache

Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
Apache License 2.0

`TypeError: __call__() takes 2 positional arguments but 3 were given` when using `@cached` and `@autocast` #21

Closed aaprasad closed 1 year ago

aaprasad commented 1 year ago

Hi there, I'm getting a strange error when trying to use autocasting and gradient caching at the same time, and was wondering if you had some insight. For context, I'm trying to train an audio-visual CLIP model as follows:

Model:

import torch
import torchvision
from torch import nn

class AVSimCLR(nn.Module):
    def __init__(self, 
                 audio_encoder_cfg: dict = {},
                 visual_encoder_cfg: dict = {},
                 projector_cfg: dict = {},
                 loss_cfg: dict = {}, 
                 optimizer_cfg: dict = {}, 
                 scheduler_cfg: dict = {}):
        # TODO: Generalize this to any encoder
        super().__init__()
        self.save_hyperparameters()

        img_shape = visual_encoder_cfg.pop("img_shape")

        audio_encoder_name = audio_encoder_cfg.pop("name")
        audio_shape = audio_encoder_cfg.pop("audio_shape")

        self.audio_encoder = ResNet1D(in_channels=audio_shape[-1], base_filters=128, kernel_size=3,
                                      stride=1, groups=1, n_block=8, use_do=False,
                                      downsample_gap=2, increasefilter_gap=3)

        self.audio_projector = ProjectionHead(in_dims=self._get_hidden_dim(encoder='audio', input_shape=audio_shape), **projector_cfg)  # three-layer MLP

        self.visual_encoder = torch.nn.Sequential(*list(torchvision.models.resnet18(weights=None).children())[:-1])

        self.visual_projector = ProjectionHead(in_dims=self._get_hidden_dim(encoder='visual', input_shape=img_shape), **projector_cfg)  # three-layer MLP

        self.loss = CAVLoss(self.device)

        self.optimizer_cfg = optimizer_cfg

        self.scheduler_cfg = scheduler_cfg

    def _get_hidden_dim(self,encoder: str, input_shape):
        if encoder.lower() == 'audio':
            x = torch.rand(1,*input_shape)
            if isinstance(self.audio_encoder, ResNet1D):
                x = x.permute(0,2,1)
            with torch.no_grad():
                z = self.audio_encoder(x)
        elif encoder.lower() == 'visual':
            x = torch.rand(1,*input_shape)
            with torch.no_grad():
                z = self.visual_encoder(x)
            z = z.reshape(z.shape[0],z.shape[1])

        return z.shape[-1]

    def forward_visual(self, imgs):
        img_emb = self.visual_encoder(imgs)
        img_emb = img_emb.reshape(img_emb.shape[:2])
        img_projected = self.visual_projector(img_emb)
        return img_projected

    def forward_audio(self, audio):
        if isinstance(self.audio_encoder,ResNet1D):
            audio = audio.permute(0,2,1)

        audio_emb = self.audio_encoder(audio)

        audio_projected = self.audio_projector(audio_emb)
        return audio_projected

    def forward(self, imgs, audio):

        img_projected = self.forward_visual(imgs)

        audio_projected = self.forward_audio(audio)

        return img_projected, audio_projected

Loss:

import gc

import torch
import numpy as np
from tqdm import tqdm
from gradcache.functional import cached, cat_input_tensor
from torch.cuda.amp import autocast

@cat_input_tensor
@autocast
def cav_loss(image_emb: torch.Tensor, audio_emb: torch.Tensor):

    logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    audio_emb = audio_emb / audio_emb.norm(dim=-1, keepdim=True)

    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)

    # cosine similarity as logits
    logit_scale = logit_scale.exp()
    logits_per_image = logit_scale * image_emb @ audio_emb.T
    logits_per_audio = logits_per_image.T
    # print(logits_per_image)
    # print(logits_per_audio)
    labels = torch.arange(image_emb.shape[0], device = image_emb.device)
    # print(labels)
    loss_image = torch.nn.functional.cross_entropy(logits_per_image, labels)
    loss_audio = torch.nn.functional.cross_entropy(logits_per_audio, labels)
    loss = (loss_image + loss_audio)/2
    return loss
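
For context, `cat_input_tensor` concatenates any list-of-tensor arguments along the batch dimension before the wrapped function runs, so `cav_loss` sees the cached sub-batch representations as one large batch. A rough sketch of that behavior, not the library's actual implementation:

import functools
import torch

def cat_input_tensor_sketch(func):
    # Rough sketch of a cat_input_tensor-style decorator: concatenate
    # any list-of-tensor arguments along dim 0 before calling the loss.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = [torch.cat(a) if isinstance(a, list) else a for a in args]
        kwargs = {k: torch.cat(v) if isinstance(v, list) else v
                  for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper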

Training:

@cached
@autocast
def call_audio_model(model, input):
    return model.forward_audio(input)

@cached
@autocast
def call_vision_model(model, input):
    return model.forward_visual(input)
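
For context on the calls below: `@cached` makes each call return a `(representation, closure)` pair. The first pass runs without gradients to produce the representation cheaply; after the loss backward fills the representation's `.grad`, the closure replays the forward with gradients and pushes that cached gradient into the model. A rough sketch of the contract, not the library's actual implementation:

import functools
import torch

def cached_sketch(func):
    # Rough sketch of a @cached-style decorator, not grad_cache's code.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # First pass: representation only, no autograd graph kept.
        with torch.no_grad():
            rep = func(*args, **kwargs)
        rep = rep.detach().requires_grad_()

        def closure(cached_rep):
            # Replay the forward with gradients and backpropagate the
            # cached gradient filled in by the loss backward.
            with torch.enable_grad():
                replay = func(*args, **kwargs)
            replay.backward(gradient=cached_rep.grad)

        return rep, closure
    return wrapper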

if __name__ == "__main__":

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 256
    n_epochs = np.inf #200
    batches_per_backward = 8
    save_every_n = 10
    max_gradient_norm = 1.0

    dataset = MMDataset(videos,
                        audio
                       )

    audio_encoder_cfg = {"name": "resnet18", "audio_shape": dataset[0]['audio'].shape, "num_microphones":4}
    visual_encoder_cfg = {"img_shape":dataset[0]['img'].shape}

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)  # , pin_memory=True, num_workers=10

    avsimclr = AVSimCLR(audio_encoder_cfg=audio_encoder_cfg,
                        visual_encoder_cfg=visual_encoder_cfg,
                        projector_cfg=projector_cfg,
                        loss_cfg=loss_cfg).to(device)

    optimizer = Lars(avsimclr.parameters(), **optimizer_cfg)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    scaler = torch.cuda.amp.GradScaler()

    cache_x = []
    cache_y = []
    closures_x = []
    closures_y = []

    losses = []

    global_loss = np.inf

    epoch = 1
    while True:
        losses = []
        with tqdm(enumerate(dataloader), desc = "step", leave = True, position = 0, unit = "batch") as batches:
            for step, sub_batch in batches:  
                audio = sub_batch['audio'].to(device)
                img = sub_batch['img'].to(device)

                rx, cx = call_audio_model(avsimclr, audio)
                ry, cy = call_vision_model(avsimclr, img)

                cache_x.append(rx)
                cache_y.append(ry)
                closures_x.append(cx)
                closures_y.append(cy)

                if (step + 1) % batches_per_backward == 0:
                    batches.set_description(f"Calculating backwards pass on step {step} of {len(dataloader)}!")

                    loss = cav_loss(cache_x, cache_y)
                    scaler.scale(loss).backward()

                    for f, r in zip(closures_x, cache_x):
                        f(r)
                    for f, r in zip(closures_y, cache_y):
                        f(r)

                    cache_x = []
                    cache_y = []
                    closures_x = []
                    closures_y = []

                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)

                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(avsimclr.parameters(), max_gradient_norm)

                    # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
                    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
                    scaler.step(optimizer)

                    # Updates the scale for next iteration.
                    scaler.update()

                    optimizer.zero_grad()

                    losses.append(loss.item())
                    batches.set_postfix_str(f"train_loss_step: {loss.item():.3f}")

                else:
                    batches.set_description(f"Epoch {epoch} step")

        train_loss_epoch = np.mean(losses)

        lr_scheduler.step(train_loss_epoch)

        epoch += 1

        torch.cuda.empty_cache()  # call empty_cache whenever we delete something from the GPU
        gc.collect()

And the error I receive is:

        File "gradcache/functional.py", line 22, in <module>
                reps_no_grad = func(*args, **kwargs)
        TypeError: __call__() takes 2 positional arguments but 3 were given
    And i believe its because its passing an autocast function into reps_no_grad bc removing `@autocast` fixes it
luyug commented 1 year ago

`autocast` is a class, which means you have to do `@autocast()`.
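
That is, `torch.cuda.amp.autocast` is a context-manager class: bare `@autocast` passes your function to `autocast.__init__`, and when `cached` later calls the resulting object with `(model, input)` it hits `autocast.__call__(self, func)`, which accepts only a single function argument, hence "takes 2 positional arguments but 3 were given". With the parentheses, the context manager is constructed first and then decorates the function. A minimal sketch of the corrected decorators, keeping the function bodies from above:

@cached
@autocast()  # parentheses: construct the context manager, then decorate
def call_audio_model(model, input):
    return model.forward_audio(input)

@cached
@autocast()
def call_vision_model(model, input):
    return model.forward_visual(input)

@cat_input_tensor
@autocast()
def cav_loss(image_emb: torch.Tensor, audio_emb: torch.Tensor):
    ...  # same loss body as above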

aaprasad commented 1 year ago

ah 😅 thanks so much @luyug!