sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.
MIT License

About relation networks #113

Closed HowCuteIsBee2002018 closed 12 months ago

HowCuteIsBee2002018 commented 1 year ago

Problem: In classical training, I tried to run a Relation Network, but it says "illegal backbone" was used. What do I need to change, given that the output must be feature maps?

Considered solutions: What have you tried but didn't work?

How can we help: Can you guide me on which code to edit?

ebennequin commented 1 year ago

Some first answers to this issue are in #24.

Please be as specific as you can when describing your error:

In the screenshot you included in issue #24, we can't see whether use_fc is set to True or not.

HowCuteIsBee2002018 commented 1 year ago

As you said to remove the pooling layer, I set use_pooling to False, as in the code below:

model = resnet12(
    #use_fc=True,
    num_classes=len(set(train_set.get_labels())),
    use_pooling=False,
).to(DEVICE)

I followed everything, only changing the Prototypical Network to a Relation Network, but when I run the training loop it gives this error:

RuntimeError                              Traceback (most recent call last)
Cell In[16], line 12
     10 for epoch in range(n_epochs):
     11     print(f"Epoch {epoch}")
---> 12     average_loss = training_epoch(model, train_loader, train_optimizer)
     14     # Store epoch and loss values for plotting
     15     epochs.append(epoch)

Cell In[15], line 8, in training_epoch(model_, data_loader, optimizer)
      5 for images, labels in tqdm_train:
      6     optimizer.zero_grad()
----> 8     loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
      9     loss.backward()
     10     optimizer.step()

File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [128]
ebennequin commented 1 year ago

During classical training, you expect your model to output a class prediction, which must be of shape (batch_size, number_of_classes_in_the_train_set). So you need the pooling and the fully connected layer in your model.
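A minimal reproduction of the error, assuming PyTorch and illustrative sizes (64 training classes and 640-channel 5x5 feature maps are assumptions; only the batch size of 128 comes from the traceback). `nn.CrossEntropyLoss` reads a 4D input of shape (N, C, H, W) as per-pixel class scores and then expects 3D targets of shape (N, H, W), so a backbone without pooling and FC fails when given one integer label per image.

```python
import torch
from torch import nn

torch.manual_seed(0)
loss_fn = nn.CrossEntropyLoss()

# Illustrative sizes (assumptions): batch of 128 as in the traceback, 64 training classes
batch_size, n_train_classes = 128, 64
labels = torch.randint(0, n_train_classes, (batch_size,))  # one label per image, shape [128]

# Backbone with use_pooling=False: feature maps of shape (batch, channels, height, width)
feature_maps = torch.randn(batch_size, 640, 5, 5)
try:
    loss_fn(feature_maps, labels)
    feature_map_error = None
except RuntimeError as error:
    # A 4D input is treated as a per-pixel prediction, which needs 3D targets
    feature_map_error = error

# Backbone with pooling and FC: one logit per training class, shape (batch, n_classes)
logits = torch.randn(batch_size, n_train_classes)
loss = loss_fn(logits, labels)  # scalar loss, no error
```

The exact wording of the error message varies across PyTorch versions, but the cause is the same shape mismatch.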

You are only using Relation Networks (and thus the additional module on top of your backbone) during evaluation (and validation). At this time, you need to remove pooling as well as the fully connected layer.
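To make the two modes concrete, here is a sketch with a hypothetical `ToyBackbone` (not EasyFSL's `resnet12`) that mirrors the `use_pooling`/`use_fc` logic of the ResNet `forward()` method quoted in this thread. The layer sizes are arbitrary assumptions; only the output shapes matter.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a ResNet with optional pooling and FC layer."""

    def __init__(self, num_classes: int, use_pooling: bool = True, use_fc: bool = True):
        super().__init__()
        # Arbitrary conv stack: 3 -> 16 channels, spatial size divided by 4
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, num_classes)
        self.use_pooling = use_pooling
        self.use_fc = use_fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        if self.use_pooling:
            x = torch.flatten(self.avgpool(x), 1)
            if self.use_fc:
                return self.fc(x)
        elif self.use_fc:
            raise ValueError("You can't use the fully connected layer without pooling features.")
        return x


model = ToyBackbone(num_classes=64)
images = torch.randn(8, 3, 32, 32)

# Classical training: class logits, as CrossEntropyLoss expects
training_output = model(images)  # shape (8, 64)

# Relation Networks evaluation: spatial feature maps
model.use_pooling, model.use_fc = False, False
evaluation_output = model(images)  # shape (8, 16, 8, 8)
```

Switching the flags only changes what `forward()` returns; the convolutional weights are shared, so the backbone trained with the FC head is the same one that produces feature maps at evaluation time.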

HowCuteIsBee2002018 commented 1 year ago

Do I edit the forward function in the ResNet module, shown below? I tried removing them, but it still doesn't work. Sorry to bother you, I am new to deep learning. Is there a way to change the code so that both the pooling and the fully connected layer are removed?

def forward(self, x: Tensor) -> Tensor:
    """
    Forward pass through the ResNet.
    Args:
        x: input tensor of shape (batch_size, image_shape)
    Returns:
        x: output tensor of shape (batch_size, num_classes) if self.use_fc is True,
            otherwise of shape (batch_size, feature_shape)
    """
    x = self.layer4(
        self.layer3(self.layer2(self.layer1(self.relu(self.bn1(self.conv1(x))))))
    )

    if self.use_pooling:
        x = torch.flatten(
            self.avgpool(x),
            1,
        )

        if self.use_fc:
            return self.fc(x)

    else:
        if self.use_fc:
            raise ValueError(
                "You can't use the fully connected layer without pooling features."
            )

    return x
ebennequin commented 1 year ago

As you can see in the forward() method of the ResNet, both pooling and the fully connected layer are optional. Simply initialize the resnet12 with use_pooling=False and use_fc=False and you should be all set.

HowCuteIsBee2002018 commented 1 year ago

I tried setting both pooling and fc to False, but it still doesn't work: I get the same error as above, "RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [128]". Is it because of the code here?


from easyfsl.utils import evaluate

best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 10
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:

        # We use this very convenient method from EasyFSL's ResNet to specify
        # that the model shouldn't use its last fully connected layer during validation.
        #model.set_use_fc(False)
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
        #model.set_use_fc(True)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_state = model.state_dict()
            print("Ding ding ding! We found a new best model!")

        tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()
ebennequin commented 1 year ago

As I wrote before, you must enable the fully connected layer and pooling before training, and disable them before validation.


for epoch in range(n_epochs):
    model.set_use_fc(True)
    model.use_pooling = True    # No explicit method here so we change the attribute directly
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:

        model.set_use_fc(False)
        model.use_pooling = False
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )

        [...]
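One way to make this switch harder to forget is a context manager. This is a minimal sketch: `feature_map_mode` and `StubBackbone` are hypothetical names, not part of EasyFSL. The helper only assumes the model exposes `use_pooling` and `use_fc` attributes, and sets both directly, as the snippet above does for `use_pooling` (assuming that assigning `use_fc` has the same effect as `set_use_fc()`). It restores the previous values on exit, even if validation raises.

```python
from contextlib import contextmanager
from typing import Iterator

import torch
from torch import nn


@contextmanager
def feature_map_mode(model: nn.Module) -> Iterator[nn.Module]:
    """Hypothetical helper: disable pooling and FC inside the block, then restore them."""
    previous = (model.use_pooling, model.use_fc)
    model.use_pooling, model.use_fc = False, False
    try:
        yield model
    finally:
        model.use_pooling, model.use_fc = previous


class StubBackbone(nn.Module):
    """Tiny hypothetical stand-in that takes feature maps and exposes the same two flags."""

    def __init__(self):
        super().__init__()
        self.use_pooling, self.use_fc = True, True
        self.fc = nn.Linear(4, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 4, H, W); mean over H and W stands in for average pooling
        if self.use_pooling:
            x = x.mean(dim=(2, 3))
            return self.fc(x) if self.use_fc else x
        return x


model = StubBackbone()
images = torch.randn(2, 4, 5, 5)
with feature_map_mode(model):
    validation_output = model(images)  # feature maps, shape (2, 4, 5, 5)
training_output = model(images)  # flags restored: logits, shape (2, 10)
```

In the loop above, the validation step would then read `with feature_map_mode(model): validation_accuracy = evaluate(few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation")`, and training always runs with pooling and FC enabled.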