Closed HowCuteIsBee2002018 closed 12 months ago
First answers to this issue on #24
Please be as specific as you can when describing your error:
In the screenshot you included in issue #24 we can't see whether use_fc
is set to True or not.
As you said to remove the pooling layer, I have set it to False, as in the code below:
model = resnet12(
    # use_fc=True,
    num_classes=len(set(train_set.get_labels())),
    use_pooling=False,
).to(DEVICE)
I have followed everything, only changing from Prototypical Networks to Relation Networks, but when I try to run the evaluation it gives this error:
RuntimeError Traceback (most recent call last)
Cell In[16], line 12
10 for epoch in range(n_epochs):
11 print(f"Epoch {epoch}")
---> 12 average_loss = training_epoch(model, train_loader, train_optimizer)
14 # Store epoch and loss values for plotting
15 epochs.append(epoch)
Cell In[15], line 8, in training_epoch(model_, data_loader, optimizer)
5 for images, labels in tqdm_train:
6 optimizer.zero_grad()
----> 8 loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
9 loss.backward()
10 optimizer.step()
File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174 return F.cross_entropy(input, target, weight=self.weight,
1175 ignore_index=self.ignore_index, reduction=self.reduction,
1176 label_smoothing=self.label_smoothing)
File ~/Desktop/conda_jhteoh/envs/jhteoh/lib/python3.11/site-packages/torch/nn/functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3027 if size_average is not None or reduce is not None:
3028 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [128]
During classical training, you expect your model to output a class prediction, which must be of shape (batch_size, number_of_classes_in_the_train_set). So you need the pooling and the fully connected layer in your model.
You are only using Relation Networks (and thus the additional module on top of your backbone) during evaluation (and validation). At this time, you need to remove pooling as well as the fully connected layer.
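To illustrate the shape requirement described above, here is a minimal sketch that does not use EasyFSL at all. The tensor sizes (64 classes, 640 channels, 5x5 feature maps) are illustrative assumptions, not values taken from this issue. It shows that cross-entropy accepts class scores of shape (batch_size, num_classes) with 1D labels, but rejects unpooled 4D feature maps paired with the same labels. The exact error message may vary across PyTorch versions.

```python
import torch
import torch.nn.functional as F

batch_size, num_classes = 128, 64  # illustrative values (assumption)
labels = torch.randint(0, num_classes, (batch_size,))

# Case 1: pooled + fully connected output, shape (batch_size, num_classes).
logits = torch.randn(batch_size, num_classes)
loss = F.cross_entropy(logits, labels)  # works and returns a scalar tensor

# Case 2: raw feature maps, shape (batch_size, channels, height, width).
# With a 4D input, cross-entropy expects 3D "spatial" targets, so 1D labels fail.
feature_maps = torch.randn(batch_size, 640, 5, 5)
try:
    F.cross_entropy(feature_maps, labels)
    error_message = None
except (RuntimeError, ValueError) as error:
    error_message = str(error)

print("loss shape:", tuple(loss.shape))
print("error with feature maps:", error_message)
```

This is why the backbone must produce class scores during classical training and feature maps only when it is used inside the Relation Networks classifier.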
Do I edit the code in the resnet module here, in the forward function? I tried removing them but it still did not work. Sorry to bother you, I am new to deep learning, but are there steps for changing the code to remove both pooling and the fully connected layer?
def forward(self, x: Tensor) -> Tensor:
    """
    Forward pass through the ResNet.
    Args:
        x: input tensor of shape (batch_size, image_shape)
    Returns:
        x: output tensor of shape (batch_size, num_classes) if self.use_fc is True,
            otherwise of shape (batch_size, feature_shape)
    """
    x = self.layer4(
        self.layer3(self.layer2(self.layer1(self.relu(self.bn1(self.conv1(x))))))
    )

    if self.use_pooling:
        x = torch.flatten(
            self.avgpool(x),
            1,
        )

        if self.use_fc:
            return self.fc(x)

    else:
        if self.use_fc:
            raise ValueError(
                "You can't use the fully connected layer without pooling features."
            )

    return x
As you can see in the forward() method of the ResNet, both pooling and the fully connected layer are optional. Simply initialize the resnet12 with use_pooling=False and use_fc=False and you should be all set.
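As a rough illustration of how these two flags change the output shape, here is a small sketch. `ToyBackbone` is a hypothetical stand-in written for this example, not EasyFSL's ResNet. It only mirrors the branching logic of the forward() method quoted above, with a single convolution in place of the residual layers.

```python
import torch
from torch import Tensor, nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in with optional pooling and fully connected layers."""

    def __init__(self, num_classes: int, use_pooling: bool = True, use_fc: bool = True):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(8, num_classes)
        self.use_pooling = use_pooling
        self.use_fc = use_fc

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.conv(x))
        if self.use_pooling:
            x = torch.flatten(self.avgpool(x), 1)
            if self.use_fc:
                return self.fc(x)
        elif self.use_fc:
            raise ValueError(
                "You can't use the fully connected layer without pooling features."
            )
        return x


images = torch.randn(4, 3, 16, 16)

# Classical training: class scores of shape (batch_size, num_classes).
shape_training = tuple(ToyBackbone(10, use_pooling=True, use_fc=True)(images).shape)
# Pooled embeddings: one feature vector per image.
shape_embedding = tuple(ToyBackbone(10, use_pooling=True, use_fc=False)(images).shape)
# Feature maps: the spatial dimensions are preserved.
shape_feature_maps = tuple(ToyBackbone(10, use_pooling=False, use_fc=False)(images).shape)

print(shape_training, shape_embedding, shape_feature_maps)
```

Only the first configuration produces an output that a cross-entropy loss can consume with 1D labels. The last one corresponds to the feature maps that the Relation Networks module is described as needing in this thread.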
I have tried setting pooling and FC to False but it won't work; the same error as above still comes up: "RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [128]". Is it because of the code here?
from easyfsl.utils import evaluate

best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 10
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:
        # We use this very convenient method from EasyFSL's ResNet to specify
        # that the model shouldn't use its last fully connected layer during validation.
        #model.set_use_fc(False)
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
        #model.set_use_fc(True)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_state = model.state_dict()
            print("Ding ding ding! We found a new best model!")

        tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()
As I wrote before, you must enable the fully connected layer and pooling before training, and disable them before validation.
for epoch in range(n_epochs):
    model.set_use_fc(True)
    model.use_pooling = True  # No explicit method here so we change the attribute directly
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:
        model.set_use_fc(False)
        model.use_pooling = False
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
    [...]
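One way to avoid forgetting to switch the flags back is to wrap the toggling in a context manager. The helper below, `feature_map_mode`, is a hypothetical name introduced for this sketch and is not part of EasyFSL. It only assumes that the model exposes `use_fc` and `use_pooling` as plain attributes, as in the forward() method quoted earlier. A `SimpleNamespace` stands in for the real model so the snippet runs on its own.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def feature_map_mode(model):
    """Temporarily disable pooling and the fully connected layer (hypothetical helper)."""
    previous = (model.use_fc, model.use_pooling)
    model.use_fc, model.use_pooling = False, False
    try:
        yield model
    finally:
        # Restore the training configuration, even if validation raises.
        model.use_fc, model.use_pooling = previous


# Stand-in object with the two flags; a real backbone would be used in practice.
dummy_model = SimpleNamespace(use_fc=True, use_pooling=True)

with feature_map_mode(dummy_model):
    # The call to evaluate(few_shot_classifier, ...) would go here.
    flags_during_validation = (dummy_model.use_fc, dummy_model.use_pooling)

flags_after_validation = (dummy_model.use_fc, dummy_model.use_pooling)
print(flags_during_validation, flags_after_validation)
```

Since the few-shot classifier and the training loop share the same backbone object, changing the attributes on `model` should also affect what the classifier sees during validation.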
Problem
In classical training, I try to run Relation Networks but it says an illegal backbone is used. What do I need to change, given that the output must be feature maps?
Considered solutions
What have you tried but didn't work?
How can we help
Can you guide me on this, for example which code to edit?