OML-Team / open-metric-learning

Metric learning and retrieval pipelines, models and zoo.
https://open-metric-learning.readthedocs.io/en/latest/index.html
Apache License 2.0
895 stars 61 forks source link

Does arcface support train on gpu? #620

Closed olegkorshunov closed 4 months ago

olegkorshunov commented 4 months ago
EPOCH = 5
device = "cuda"
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = ArcFaceLoss(in_features=emb_size, num_classes=labels_amount)
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=3, n_instances=5)

def training():
    model.train()
    for _ in range(EPOCH):
        for batch in tqdm(DataLoader(train_dataset, batch_sampler=sampler)):
            embeddings = model(batch["input_tensors"].to(device))
            loss = criterion(embeddings, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(criterion.last_logs)

error log

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], [line 1](vscode-notebook-cell:?execution_count=13&line=1)
----> [1](vscode-notebook-cell:?execution_count=13&line=1) training()

Cell In[12], [line 14](vscode-notebook-cell:?execution_count=12&line=14)
     [12](vscode-notebook-cell:?execution_count=12&line=12) for batch in tqdm(DataLoader(train_dataset, batch_sampler=sampler)):
     [13](vscode-notebook-cell:?execution_count=12&line=13)     embeddings = model(batch["input_tensors"].to(device))
---> [14](vscode-notebook-cell:?execution_count=12&line=14)     loss = criterion(embeddings, batch["labels"].to(device))
     [15](vscode-notebook-cell:?execution_count=12&line=15)     loss.backward()
     [16](vscode-notebook-cell:?execution_count=12&line=16)     optimizer.step()

File ~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1516](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1516)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1517](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1517) else:
-> [1518](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1518)     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   [1522](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1522) # If we don't have any hooks, we want to skip the rest of the logic in
   [1523](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1523) # this function, and just call forward.
   [1524](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1524) if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   [1525](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1525)         or _global_backward_pre_hooks or _global_backward_hooks
   [1526](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1526)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1527](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1527)     return forward_call(*args, **kwargs)
   [1529](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1529) try:
   [1530](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/torch/nn/modules/module.py:1530)     result = None

File ~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:78, in ArcFaceLoss.forward(self, x, y)
     [75](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:75) def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     [76](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:76)     assert torch.all(y < self.num_classes), "You should provide labels between 0 and num_classes - 1."
---> [78](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:78)     cos = self.fc(x)
     [80](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:80)     self._log_accuracy_on_batch(cos, y)
     [82](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:82)     sin = torch.sqrt(1.0 - torch.pow(cos, 2))

File ~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:68, in ArcFaceLoss.fc(self, x)
     [67](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:67) def fc(self, x: torch.Tensor) -> torch.Tensor:
---> [68](https://vscode-remote+wsl-002bubuntu-002d22-002e04.vscode-resource.vscode-cdn.net/mnt/e/projects/face_recognition/jupyter/~/miniconda3/envs/main/lib/python3.10/site-packages/oml/losses/arcface.py:68)     return F.linear(F.normalize(x, p=2), F.normalize(self.weight, p=2))

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
AlekseySh commented 4 months ago

hey, @olegkorshunov and thank you for using OML!

Yes, it does work with ArcFace on GPU. You simply forgot to move your criterion to the cuda:0, which is needed because ArcFace has its classification head.

Here is a fully working example (the one from Readme, but with ArcFace):

from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import ArcFaceLoss
from oml.metrics import calc_retrieval_metrics_rr
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset

device = "cuda:0"
model = ViTExtractor.from_pretrained("vits16_dino").to(device).train()
transform, _ = get_transforms_for_pretrained("vits16_dino")

df_train, df_val = get_mock_images_dataset(global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transform)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = ArcFaceLoss(in_features=384, num_classes=4).to(device)  # <<<<<<<<<<<<<<<<<< HERE
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)

def training():
    for batch in DataLoader(train, batch_sampler=sampler):
        embeddings = model(batch["input_tensors"].to(device))
        loss = criterion(embeddings, batch["labels"].to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(criterion.last_logs)

def validation():
    embeddings = inference(model, val, batch_size=4, num_workers=0)
    rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
    rr = AdaptiveThresholding(n_std=2).process(rr)
    rr.visualize(query_ids=[2, 1], dataset=val, show=True)
    print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))

training()
validation()
olegkorshunov commented 4 months ago

Thx!

olegkorshunov commented 4 months ago

hey, @olegkorshunov and thank you for using OML!

Yes, it does work with ArcFace on GPU. You simply forgot to move your criterion to the cuda:0, which is needed because ArcFace has its classification head.

Here is a fully working example (the one from Readme, but with ArcFace):

from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import ArcFaceLoss
from oml.metrics import calc_retrieval_metrics_rr
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset

device = "cuda:0"
model = ViTExtractor.from_pretrained("vits16_dino").to(device).train()
transform, _ = get_transforms_for_pretrained("vits16_dino")

df_train, df_val = get_mock_images_dataset(global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transform)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = ArcFaceLoss(in_features=384, num_classes=4).to(device)  # <<<<<<<<<<<<<<<<<< HERE
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)

def training():
    for batch in DataLoader(train, batch_sampler=sampler):
        embeddings = model(batch["input_tensors"].to(device))
        loss = criterion(embeddings, batch["labels"].to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(criterion.last_logs)

def validation():
    embeddings = inference(model, val, batch_size=4, num_workers=0)
    rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
    rr = AdaptiveThresholding(n_std=2).process(rr)
    rr.visualize(query_ids=[2, 1], dataset=val, show=True)
    print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))

training()
validation()

I try to use your code example and is it correct that criterion.last_logs become zero?

AlekseySh commented 4 months ago

@olegkorshunov yep, it's fine, it's all about hyperparameters and training time. I set lr=1e-5 and 5 epochs and got accuracy==1:

...
optimizer = Adam(model.parameters(), lr=1e-5)
...

for _ in range(5):
    training()

validation()