Closed: Manasa-3 closed this issue 3 months ago.
Can you post the code around line 65 and a few lines after it?
# Build the data loaders, the model, loss and optimizer, then run a
# mixed-precision training loop for `num_epochs` epochs, logging the loss.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=custom_collate_fn, pin_memory=True)
print(f"Number of workers: {train_loader.num_workers}")
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=custom_collate_fn, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=custom_collate_fn, pin_memory=True)

# Initialize model, loss function and optimizer.
# The model's input_shape must have 4 entries (W, H, D, C). Passing the
# 3-entry image.shape, e.g. (256, 128, 128), is what made
# BatchNormalization.build fail with "IndexError: tuple index out of range".
# A channel axis of size 1 is appended here -- assumes single-channel
# volumes, TODO confirm.
# NOTE(review): the traceback shows smp.Unet building Keras layers, while the
# code below uses PyTorch APIs (model.parameters(), model.train(),
# GradScaler/autocast). Keras models do not provide .parameters()/.train();
# confirm which library `smp` refers to and use a single framework.
model = smp.Unet(
    'resnet34',
    encoder_weights=None,
    input_shape=tuple(image.shape) + (1,),
)

criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

num_epochs = 10  # Define the number of epochs
log_every = 1  # print the average loss of the last `log_every` batches

scaler = GradScaler()

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # loss summed since the last log line
    epoch_loss = 0.0    # loss summed over the whole epoch
    for batch_idx, (images, masks) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        # NOTE(review): unsqueeze(0) inserts a size-1 axis in front of the
        # batch axis that DataLoader already added (batch_size=2). A channel
        # axis is probably intended instead: dim 1 for channels-first PyTorch
        # layers, dim -1 for a channels-last (W, H, D, C) model -- confirm.
        images = images.unsqueeze(0)
        optimizer.zero_grad()
        # Only the forward pass and loss run under autocast; backward runs
        # outside it, as recommended for PyTorch AMP.
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, masks)
        # Scale the loss and call backward
        scaler.scale(loss).backward()
        # Unscales the gradients and calls optimizer.step
        scaler.step(optimizer)
        # Updates the scale for next iteration
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        # Divide by the number of batches actually summed since the last
        # reset (not by batch_idx + 1, which kept growing while the sum was
        # reset every batch and so under-reported the loss).
        if (batch_idx + 1) % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/log_every:.4f}")
            running_loss = 0.0
    print(f"Epoch [{epoch+1}/{num_epochs}] average loss: {epoch_loss / max(len(train_loader), 1):.4f}")
    # Collect Python garbage once per epoch rather than on every batch.
    gc.collect()
This is the code from line 65 onward.
What is the value of `image.shape`?
From the example:
# Non-standard input shape: pass a 4-tuple (W, H, D, C) -- three spatial
# sizes followed by the number of channels (here 96 x 128 x 128, 6 channels).
model5 = sm.Unet(
    'resnet50',
    encoder_weights=None,
    input_shape=(96, 128, 128, 6),
)
It must be a tuple of 4 numbers: (W, H, D, C).
Okay, thank you, I'll check it out.
IndexError                                Traceback (most recent call last)
in <cell line: 65>()
63
64 # Initialize Model, Loss Function, and Optimizer
---> 65 model = smp.Unet(
66 'resnet34',
67 encoder_weights=None,
7 frames
/usr/local/lib/python3.10/dist-packages/keras/src/layers/normalization/batch_normalization.py in build(self, input_shape)
    168
    169     def build(self, input_shape):
--> 170         shape = (input_shape[self.axis],)
    171         if self.scale:
    172             self.gamma = self.add_weight(
IndexError: tuple index out of range

Please help me with this error. My image has shape (256, 128, 128).