KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

(BUG) Expected input and other to have the same dtype, but got input's dtype Float and other's dtype Double #238

Closed soulcarus closed 3 months ago

soulcarus commented 5 months ago

Title: Error training time series forecasting model for cement mill variables

Description:

Hi, I'm trying to train a time series forecasting model that predicts cement mill variables.

I am encountering an error when running the following code. Below is the code snippet and the error message I'm receiving.

Code:

train_labels = torch.tensor(train_raw[block_r].values).unsqueeze(-1)
test_labels = torch.tensor(test_raw[block_r].values).unsqueeze(-1)

train_input = torch.tensor(np.array([train_raw[in_col].values for in_col in block.columns])).transpose(0, 1)
test_input = torch.tensor(np.array([test_raw[in_col].values for in_col in block.columns])).transpose(0, 1)

# I've already converted everything to .double() because of this error and run it again, but the result is the same

train_labels = train_labels.double()
test_labels = test_labels.double()
train_input = train_input.double()
test_input = test_input.double()

prepped_dataset = {
    'train_input': train_input,
    'train_label': train_labels,
    'test_input': test_input,
    'test_label': test_labels
}
date_only_dataset = {
    'train_input': train_input[:, :2],
    'train_label': train_labels,
    'test_input': test_input[:, :2],
    'test_label': test_labels
}
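For readers following along, here is a minimal sketch of how the dtype of these tensors comes about. It uses made-up random data in place of `train_raw`, and the helper name `to_default_dtype` is hypothetical. The point it illustrates is that `torch.tensor` keeps NumPy's `float64`, so the `.double()` calls above do not change anything, while PyTorch's default dtype for newly created tensors is normally `float32`. Whether the KAN models use that default internally is an assumption, not something confirmed in this thread.

```python
import numpy as np
import torch

# Made-up stand-in for the DataFrame values (NumPy defaults to float64).
raw = np.random.default_rng(0).normal(size=(10, 3))

def to_default_dtype(array):
    """Hypothetical helper: convert an array to PyTorch's current default dtype."""
    return torch.as_tensor(array, dtype=torch.get_default_dtype())

as_is = torch.tensor(raw)           # keeps float64, same as calling .double()
converted = to_default_dtype(raw)   # float32 unless the default was changed

print(as_is.dtype, converted.dtype, torch.get_default_dtype())
```

If the models do hold `float32` tensors, building the dataset with a helper like this (or calling `.float()` instead of `.double()`) would keep inputs and model in the same dtype.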

Creating the models

import copy

model_date_only = KAN([2,3,3,1], grid=10, k=3, seed=0)
model_width = [len(block.columns),3,3,1]
model_coarse = KAN(width=copy.copy(model_width), grid=5, k=3, seed=0)
model_fine = KAN(width=copy.copy(model_width), grid=10, k=3, seed=0)

Running the model_date_only model:

model_date_only(train_input[:,:2])
model_date_only.plot(beta=100)

import matplotlib.pyplot as plt
import copy

HIDDEN_KEYS = ['reg']
STEPS = 3

loss_data = model_date_only.train(date_only_dataset, opt="LBFGS", steps=STEPS, lamb=0.0001, lamb_entropy=10.)

x = np.arange(STEPS)
for line_key in loss_data:
    if line_key in HIDDEN_KEYS:
        continue
    plt.plot(x, loss_data[line_key], label=str(line_key).replace('_',' ').title())
plt.legend()
plt.show()

Running the model_coarse and model_fine models:

model_coarse(train_input)
model_coarse.plot(beta=100)

import matplotlib.pyplot as plt
import copy

HIDDEN_KEYS = ['reg']
STEPS = 1

loss_data_coarse = model_coarse.train(prepped_dataset, opt="LBFGS", steps=STEPS, lamb=0.0001, lamb_entropy=10.)
model_fine.initialize_from_another_model(model_coarse, train_input)
loss_data_fine = model_fine.train(prepped_dataset, opt="LBFGS", steps=STEPS, lamb=0.0001, lamb_entropy=10.)

loss_data = {}
for loss_data_subset in [loss_data_coarse,loss_data_fine]:
    for loss_key in loss_data_subset:
        if loss_key not in loss_data:
            loss_data[loss_key] = []
        loss_data[loss_key].extend(loss_data_subset[loss_key])

x = np.arange(STEPS*2)
for line_key in loss_data:
    if line_key in HIDDEN_KEYS:
        continue
    plt.plot(x, loss_data[line_key], label=str(line_key).replace('_',' ').title())
plt.legend()
plt.show()

Error Message:

In the previous cell, specifically on this line:

model_fine.initialize_from_another_model(model_coarse, train_input)

I got the following error:

torch.linalg.lstsq: Expected input and other to have the same dtype, but got input's dtype Float and other's dtype Double

Why is this happening? Everything before that worked fine with no errors, but it breaks when I get to this part.
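The error text itself can be reproduced with plain PyTorch, independent of pykan. The sketch below is only an illustration with made-up tensors: `torch.linalg.lstsq` refuses to mix a `float32` matrix with a `float64` right-hand side, whereas many elementwise operations silently promote mixed dtypes, which may be why the earlier cells ran without complaint. That `initialize_from_another_model` ends up in such a call with one `float32` tensor of its own and one `float64` tensor derived from the data is an assumption based on the message, not something verified against the library source. The helper name `solve_least_squares` is hypothetical.

```python
import torch

def solve_least_squares(a, b):
    """Least-squares solve after casting b to a's dtype (hypothetical helper)."""
    return torch.linalg.lstsq(a, b.to(a.dtype)).solution

a = torch.randn(8, 3, dtype=torch.float32)   # e.g. a tensor created by the model
b = torch.randn(8, 1, dtype=torch.float64)   # e.g. data converted with .double()

# Mixed dtypes: lstsq raises instead of promoting.
try:
    torch.linalg.lstsq(a, b)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(err)

# Same dtypes: the solve goes through.
x = solve_least_squares(a, b)
print(x.dtype, tuple(x.shape))
```

Under that assumption, two things are worth trying: pass `float32` inputs (`.float()` rather than `.double()`), or call `torch.set_default_dtype(torch.float64)` before constructing the models so that everything is created in double precision.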

Environment:

I'm running on Jupyter Notebook LTS.

soulcarus commented 5 months ago

First plot: [image]

Loss plot (model_date_only): [image]

Plot after training: [image]

Plot (model_coarse): [image]

Raw DataFrame (it is later split into two parts, 80% and 20%; I needed to hide the other columns): [image]

Prepped dataset: [image]

Train input: [image]

Test input: [image] [image]

Train labels: [image]

Test labels: [image]

Bug: [image]

KindXiaoming commented 3 months ago

Please try the most up-to-date version and see if the problem persists.
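To act on this suggestion it helps to know which version is currently installed. A small sketch using only the standard library follows. It assumes the package was installed from PyPI under the distribution name `pykan`; if it was installed from a local clone, the name or result may differ. The helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None

# Assumed distribution name; prints None if it is not installed under this name.
print(installed_version("pykan"))
```

Assuming a pip-based install, `pip install --upgrade pykan` followed by a kernel restart in Jupyter should pick up the newer release.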

soulcarus commented 3 months ago

Thanks! :D