somepago / saint

The official PyTorch implementation of the paper "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training"
Apache License 2.0

reps = model_saint.transformer(x_categ_enc, x_cont_enc)==nan #2

Closed: abebe9849 closed this issue 3 years ago

abebe9849 commented 3 years ago

I applied this great model to a regression task, but the output becomes NaN inside the model.transformer part.

class RowColTransformer(nn.Module):
    ...
    def forward(self, x, x_cont=None, mask=None):
        if x_cont is not None:
            x = torch.cat((x, x_cont), dim=1)
        _, n, _ = x.shape
        print("TRANSFORMER")  # debug print
        if self.style == 'colrow':
            for attn1, ff1, attn2, ff2 in self.layers:
                x = attn1(x)  # <-- here x == nan

Did this happen in your implementation as well? If anyone has used this model on their own data, please let me know.

These are the hyperparameters:

model_saint = SAINT(
    categories = tuple(cat_dims.values()),  # len(cat_dims) == 2
    num_continuous = len(numerical_features) + 1,
    dim = 128,
    dim_out = 1,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    mlp_hidden_mults = (4, 2),
    continuous_mean_std = None,
    cont_embeddings = "MLP",
    attentiontype = 'col',
    final_mlp_style = 'sep',
    y_dim = 1
)

Optimizer: AdamW(model_saint.parameters(), lr=1e-3, weight_decay=5e-5)
Batch size: 256
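
To narrow down where the NaNs first appear, here is a minimal sketch that registers a forward hook on every submodule of model_saint (the instance configured above) and reports any module whose output contains NaNs; the helper name add_nan_hooks is hypothetical, not from the repo:

import torch

def add_nan_hooks(model):
    # Register a forward hook on every submodule; each hook prints the
    # module's name when its output contains NaNs, so the first line
    # printed points at where the NaNs originate.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                print(f"NaNs in output of {name} ({type(module).__name__})")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# add_nan_hooks(model_saint)
# reps = model_saint.transformer(x_categ_enc, x_cont_enc)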
somepago commented 3 years ago

Hi,

Did you ensure that the x passed to the attention layer [i.e., the x passed to attn1(x)] does not contain any NaNs?
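
A minimal way to check that, assuming x_categ_enc and x_cont_enc are the embedded inputs from the issue title:

import torch

# Guard the transformer inputs: if either assertion fires, the NaNs come
# from upstream (raw data or the embedding step), not the attention layers.
assert not torch.isnan(x_categ_enc).any(), "x_categ_enc contains NaNs"
assert not torch.isnan(x_cont_enc).any(), "x_cont_enc contains NaNs"
reps = model_saint.transformer(x_categ_enc, x_cont_enc)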

abebe9849 commented 3 years ago

I made sure the input did not contain any NaNs, and it now runs without any problems. Thank you very much.
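
For anyone hitting the same problem: the fix amounts to imputing missing values before they reach the model. A minimal sketch, assuming the raw features live in a pandas DataFrame df with the columns listed in numerical_features (the column names below are toy stand-ins for illustration):

import numpy as np
import pandas as pd

# Toy stand-ins for the real data (hypothetical names).
numerical_features = ["f0", "f1"]
df = pd.DataFrame({"f0": [1.0, np.nan, 3.0], "f1": [np.nan, 2.0, 2.0]})

# Fill NaNs in the continuous columns with the column mean; any
# imputation strategy works, what matters is that no NaN reaches
# the embedding layers.
for col in numerical_features:
    df[col] = df[col].fillna(df[col].mean())

# Final guard before building tensors:
assert not df[numerical_features].isna().any().any()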