mw66 opened this issue 1 year ago
Following https://github.com/cezannec/capsule_net_pytorch/issues/4#issuecomment-629573681, changing `sparselearning/core_kernel.py` as below worked:

```diff
- new_mask.data.view(-1)[:num_remain] = 1.0
+ new_mask.data.contiguous().view(-1)[:num_remain] = 1.0
```
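Not from the repo, just a minimal sketch of the underlying PyTorch behavior: `transpose`/`permute` only swap strides, so the result is non-contiguous and `.view()` cannot reinterpret its memory layout.

```python
import torch

m = torch.zeros(4, 3).t()        # transposed -> non-contiguous
print(m.is_contiguous())         # False
# m.view(-1)                     # raises the RuntimeError quoted below
flat = m.contiguous().view(-1)   # .contiguous() copies into dense memory
flat[:5] = 1.0                   # note: this writes into the copy, not m
```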
May I check if you managed to reproduce the paper's results? Thanks.
Hi, I ran the code with the daily_sport dataset, but got very bad performance (train accuracy only 0.04 = 200/4560). How about you?
Hi @houchenyu @mw66, please help, I am facing an issue with the following training loop:

```python
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()

    for idx, batch in enumerate(training_loader):
        ids = batch['ids']
        ids = ids.transpose(0, 1)
        ids = ids.to(device)
        mask = batch['mask']
        mask = mask.transpose(0, 1)
        mask = mask.to(device)
        targets = batch['targets']
        targets = targets.transpose(0, 1)
        targets = targets.to(device)

        optimizer.zero_grad()
        print(ids.shape)
        print(mask.shape)

        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
        loss, tr_logits = outputs.loss, outputs.logits
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)

        if idx % 100 == 0:
            loss_step = tr_loss / nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")

        # compute training accuracy
        flattened_targets = targets.view(-1)  # shape (batch_size * seq_len,)
        active_logits = tr_logits.view(-1, model.num_labels)  # shape (batch_size * seq_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, axis=1)  # shape (batch_size * seq_len,)
        # now, use mask to determine where we should compare predictions
        # with targets (includes [CLS] and [SEP] token predictions)
        active_accuracy = mask.view(-1) == 1  # shape (batch_size * seq_len,)
        targets = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)

        tr_preds.extend(predictions)
        tr_labels.extend(targets)

        tmp_tr_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
        tr_accuracy += tmp_tr_accuracy

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=MAX_GRAD_NORM
        )

        # backward pass
        loss.backward()
        optimizer.step()

    epoch_loss = tr_loss / nb_tr_steps
    tr_accuracy = tr_accuracy / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {tr_accuracy}")
```
It fails with:

```
RuntimeError                              Traceback (most recent call last)
3 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/roberta/modeling_roberta.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
   1230         elif self.config.problem_type == "single_label_classification":
   1231             loss_fct = CrossEntropyLoss()
-> 1232             loss = loss_fct(logits.contiguous().view(-1, self.num_labels), labels.contiguous().view(-1))
   1233         elif self.config.problem_type == "multi_label_classification":
   1234             loss_fct = BCEWithLogitsLoss()

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
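A guess at the root cause (not verified against your notebook): the three `.transpose(0, 1)` calls make `ids`, `mask`, and `targets` non-contiguous, so the `.view(-1)` inside the Hugging Face loss computation fails, exactly as in the `core_kernel.py` case above. Since RoBERTa expects inputs shaped `(batch_size, seq_len)`, the transposes may be unnecessary altogether:

```python
# Sketch, assuming the DataLoader already yields (batch_size, seq_len)
# tensors: drop the transposes entirely.
ids = batch['ids'].to(device)
mask = batch['mask'].to(device)
targets = batch['targets'].to(device)

# If the transpose really is needed, make the result contiguous before
# handing it to the model:
targets = batch['targets'].transpose(0, 1).contiguous().to(device)
```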
For reference, running:

```
python trainer_DSN.py --sparse True --density 0.2 --sparse_init remain_sort --fix False --growth random --depth 4 --ch_size 47 --c_size 3 --k_size 39 \
    --root /data/UCR_TS_Archive_2015
```

fails with:

```
Traceback (most recent call last):
  File "/home/./project/contrib/DSN/trainer_DSN.py", line 193, in <module>
    main(args)
  File "/home/./project/contrib/DSN/trainer_DSN.py", line 116, in main
    mask.add_module(model, sparse_init=args.sparse_init, density=args.density)
  File "/home/./project/contrib/DSN/sparselearning/core_kernel.py", line 241, in add_module
    self.init(mode=sparse_init, density=density)
  File "/home/./project/contrib/DSN/sparselearning/core_kernel.py", line 175, in init
    new_mask.data.view(-1)[:num_remain] = 1.0
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
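One possible caveat with the one-line patch above (a sketch, not tested against this repo): when `new_mask.data` is non-contiguous, `.contiguous()` returns a copy, so the in-place write lands in that copy rather than in `new_mask` itself. Rebinding first keeps the write visible:

```python
new_mask.data = new_mask.data.contiguous()
new_mask.data.view(-1)[:num_remain] = 1.0
```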