Hello there! Thanks for introducing the RINorm class, but it seems its behavior during inference could be improved.
Bug description
When debugging NBEATS with `use_reversible_instance_norm=True`, I found strange results: the first prediction point after the `.predict` method seemed okay, but from the second one onward it did not. The cause is the `io_processor` wrapper. It accesses the inputs and normalizes them with `rin`, but it does NOT clone them first. Because the tensor is passed by reference, the same target gets normalized again and again on each iteration of `.predict`.
To Reproduce
You can use the following code to reproduce it (darts == 0.27.1). I only overrode some methods to add console logging, so the logic that produces the outputs is essentially the same.
from functools import wraps
import torch
from darts.models import NBEATSModel
from darts.models.forecasting.nbeats import _NBEATSModule
from darts.timeseries import TimeSeries
from typing import Tuple
class Counter:
    """Mutable integer counter, used to number the logged predict steps."""

    def __init__(self):
        # start from zero; reuses reset() so the initial state has one definition
        self.reset()

    def incr(self):
        """Advance the counter by one."""
        self.value = self.value + 1

    def reset(self):
        """Set the counter back to zero."""
        self.value = 0
# Separate step counters for the leaky and the fixed module variants, so their
# logged step numbers are tracked independently.
LEAK_COUNTER, SAFE_COUNTER = Counter(), Counter()
def io_processor(forward):
    """Wrap a module ``forward`` with reversible instance normalization (RIN).

    If ``self.use_reversible_instance_norm`` is falsy, the wrapped ``forward``
    is called unchanged. Otherwise:

    1. The past-features tensor ``args[0][0]`` is cloned.
    2. The clone's first ``self.n_targets`` feature channels are normalized
       with ``self.rin``.
    3. The wrapped ``forward`` runs on the normalized clone.
    4. Its first output is mapped back through ``self.rin.inverse``. Any
       further outputs of a tuple result are returned unchanged.

    The clone is essential. Without it, the slice assignment writes into the
    caller's tensor, so repeated calls on the same input (e.g. the successive
    steps of ``predict``) would normalize an already-normalized target again.
    """

    @wraps(forward)
    def forward_wrapper(self, *args, **kwargs):
        if not self.use_reversible_instance_norm:
            return forward(self, *args, **kwargs)

        # args[0] is the input batch tuple; its first element holds the past
        # features, starting with the first `n_targets` target features
        x_in = args[0]
        # clone so the in-place slice assignment below never touches the caller's tensor
        x: torch.Tensor = x_in[0].clone()
        x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets])

        # run the forward pass on the normalized copy; other batch elements pass through
        out = forward(self, (x, *x_in[1:]), *args[1:], **kwargs)

        # inverse-transform the target output back to the original scale; by
        # definition the target is the first output, the rest is returned as-is
        if isinstance(out, tuple):
            return self.rin.inverse(out[0]), *out[1:]
        return self.rin.inverse(out)

    return forward_wrapper
def io_processor_with_leak(forward):
    """RIN ``forward`` wrapper that intentionally keeps the in-place bug.

    The first ``self.n_targets`` channels of the past-features tensor
    ``args[0][0]`` are normalized with ``self.rin``. The result is written back
    into that same tensor, so the caller's input is modified as a side effect.
    The wrapped ``forward`` then runs, and its first output is mapped back
    through ``self.rin.inverse``. This variant exists only to reproduce the
    issue, so do not "fix" it.
    """

    @wraps(forward)
    def forward_wrapper(self, *args, **kwargs):
        if self.use_reversible_instance_norm:
            batch = args[0]
            past = batch[0]  # NOT cloned: the write below mutates the caller's tensor
            n = self.n_targets
            past[:, :, :n] = self.rin(past[:, :, :n])
            result = forward(self, (past, *batch[1:]), *args[1:], **kwargs)
            if not isinstance(result, tuple):
                return self.rin.inverse(result)
            # only the first (target) output is inverse-transformed
            head, *tail = result
            return (self.rin.inverse(head), *tail)
        return forward(self, *args, **kwargs)

    return forward_wrapper
def run_nbeats_forward(self, x_in):
    """Run the N-BEATS stacks of module ``self`` on one input batch.

    ``x_in`` is a ``(past_features, _)`` tuple; the second element is ignored.
    Returns a tensor of shape
    ``(batch, output_chunk_length, output_dim, nr_params)``.
    """
    past, _ = x_in
    batch_size = past.shape[0]

    # If x1, x2, ... y1, y2, ... is one multivariate ts containing x and y, and
    # a1, a2, ... one covariate ts, flatten into x1, y1, a1, x2, y2, a2, ...
    # -> one univariate vector per sample: (batch, input_chunk_length_multi)
    residual = past.reshape(batch_size, self.input_chunk_length_multi)

    # accumulated forecast: one vector of length target_length per distribution parameter
    forecast = torch.zeros(
        (batch_size, self.target_length, self.nr_params),
        device=residual.device,
        dtype=residual.dtype,
    )

    for block_stack in self.stacks_list:
        # each stack consumes the previous residual and adds its own partial forecast
        residual, partial = block_stack(residual)
        forecast = forecast + partial

    # Undo the interleaving: (batch, output_chunk_length, input_dim, nr_params).
    # Covariates sit to the right of the targets by construction, so keeping the
    # first `output_dim` components drops them.
    per_component = forecast.view(
        batch_size, self.output_chunk_length, self.input_dim, self.nr_params
    )
    return per_component[:, :, : self.output_dim, :]
def run_produce_predict_output(self, x, counter: Counter):
    """Produce one prediction step for module ``self``, logging its input.

    With a likelihood, the forward output is turned into likelihood parameters
    or samples, depending on ``self.predict_likelihood_parameters``. Without
    one, the flattened ``x[0]`` is printed before and after the forward pass,
    which exposes any in-place modification of the input. ``counter`` is then
    advanced, and the output is returned with its last dimension squeezed.
    """
    if not self.likelihood:
        step = counter.value
        if step == 0:  # prevent overlap with predicting pbar
            print(f"\nproducing predictions of {self.__class__.__name__}...")
        print(f"step {step}: before forward = {x[0].view(-1)}")
        prediction = self(x).squeeze(dim=-1)
        print(f"step {step}: after forward = {x[0].view(-1)}")
        counter.incr()
        return prediction

    raw_output = self(x)
    if self.predict_likelihood_parameters:
        return self.likelihood.predict_likelihood_parameters(raw_output)
    return self.likelihood.sample(raw_output)
class _NBEATSModuleNoLeak(_NBEATSModule):
    """``_NBEATSModule`` whose forward is wrapped by the cloning ``io_processor``.

    Its predict-step logging advances ``SAFE_COUNTER``.
    """

    def _produce_predict_output(self, x: Tuple) -> torch.Tensor:
        return run_produce_predict_output(self, x, counter=SAFE_COUNTER)

    @io_processor
    def forward(self, x_in: Tuple):
        return run_nbeats_forward(self, x_in=x_in)
class _NBEATSModuleWithLeak(_NBEATSModule):
    """``_NBEATSModule`` whose forward uses the in-place ``io_processor_with_leak``.

    Its predict-step logging advances ``LEAK_COUNTER``.
    """

    def _produce_predict_output(self, x: Tuple) -> torch.Tensor:
        return run_produce_predict_output(self, x, counter=LEAK_COUNTER)

    @io_processor_with_leak
    def forward(self, x_in: Tuple):
        return run_nbeats_forward(self, x_in=x_in)
def _build_nbeats_module(
    model: NBEATSModel, module_cls: type, train_sample: Tuple[torch.Tensor]
) -> torch.nn.Module:
    """Instantiate ``module_cls`` from ``model``'s hyper-parameters and a train sample.

    The input dimension is the number of past-target components plus the
    number of past-covariate components (if any). The output dimension is
    taken from the last sample element. All N-BEATS hyper-parameters, plus
    ``model.pl_module_params``, are forwarded to ``module_cls``.
    """
    # samples are made of (past_target, past_covariates, future_target)
    past_target, past_covariates = train_sample[0], train_sample[1]
    input_dim = past_target.shape[1] + (
        past_covariates.shape[1] if past_covariates is not None else 0
    )
    output_dim = train_sample[-1].shape[1]
    nr_params = 1 if model.likelihood is None else model.likelihood.num_parameters
    return module_cls(
        input_dim=input_dim,
        output_dim=output_dim,
        nr_params=nr_params,
        generic_architecture=model.generic_architecture,
        num_stacks=model.num_stacks,
        num_blocks=model.num_blocks,
        num_layers=model.num_layers,
        layer_widths=model.layer_widths,
        expansion_coefficient_dim=model.expansion_coefficient_dim,
        trend_polynomial_degree=model.trend_polynomial_degree,
        batch_norm=model.batch_norm,
        dropout=model.dropout,
        activation=model.activation,
        **model.pl_module_params,
    )


class NBEATSModelInitial(NBEATSModel):
    """``NBEATSModel`` that builds ``_NBEATSModuleWithLeak`` (in-place RIN wrapper)."""

    def __init__(self, input_chunk_length: int, output_chunk_length: int, **kwargs):
        super().__init__(input_chunk_length, output_chunk_length, **kwargs)

    def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
        return _build_nbeats_module(self, _NBEATSModuleWithLeak, train_sample)


class NBEATSModelNoLeak(NBEATSModel):
    """``NBEATSModel`` that builds ``_NBEATSModuleNoLeak`` (cloning RIN wrapper)."""

    def __init__(self, input_chunk_length: int, output_chunk_length: int, **kwargs):
        super().__init__(input_chunk_length, output_chunk_length, **kwargs)

    def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
        return _build_nbeats_module(self, _NBEATSModuleNoLeak, train_sample)
def run_model(model: NBEATSModel):
    """Fit ``model`` on a toy series and log how RIN treats the model inputs.

    Steps:

    1. Fit on the 19-point series ``torch.linspace(0.1, 1.9, 19)``.
    2. Call the underlying torch module once on the first
       ``input_chunk_length`` values, and print that input before and after
       the call. A difference means the forward pass modified it in place.
    3. Call ``predict(n=2)`` and print the predicted values.
    """
    print('*' * 40)
    print(f"running {model.__class__.__name__}")
    sample_data_values = torch.linspace(0.1, 1.9, 19)
    darts_series = TimeSeries.from_values(sample_data_values.cpu().numpy())
    model.fit(series=[darts_series])

    input_slice = sample_data_values[:model.input_chunk_length]
    input_slice_copy = input_slice.clone()
    # add batch and feature dims: (1, input_chunk_length, 1)
    input_slice = input_slice.unsqueeze(0).unsqueeze(-1)

    # call forward() directly; no gradients are needed, only the effect on the input
    with torch.no_grad():
        _ = model.model((input_slice, None))
    print((f'\ninput before forward: {input_slice_copy}\n'
           f'after forward: {input_slice.view(*input_slice_copy.shape)}\n'))

    # call predict; reset both step counters so the logged step numbers start at 0
    # whichever module variant this model uses
    input_darts_series = TimeSeries.from_values(input_slice_copy.cpu().numpy())
    LEAK_COUNTER.reset()
    SAFE_COUNTER.reset()
    pred = model.predict(series=input_darts_series, n=2)
    # public accessor for the (time, component, sample) array instead of private `_xa`
    pred_data = pred.all_values().reshape(-1)
    print(pred_data)
def main():
    """Build the leaky and the fixed NBEATS variants, then run them one after another."""
    m_kwargs = {
        "input_chunk_length": 3,
        "output_chunk_length": 1,
        "use_reversible_instance_norm": True,
        "n_epochs": 10,
    }
    models = (NBEATSModelInitial(**m_kwargs), NBEATSModelNoLeak(**m_kwargs))
    for position, model in enumerate(models):
        if position > 0:
            print("\n\n")  # visual separator between the two runs
        run_model(model)
# run the reproduction only when executed as a script, not on import
if __name__ == "__main__":
    main()
Hello there! Thanks for introducing the `RINorm` class, but it seems its behavior during inference could be improved.

Bug description: When debugging NBEATS with `use_reversible_instance_norm=True`, I found strange results: the first prediction point after the `.predict` method seemed okay, but from the second one onward it did not. The cause is the `io_processor` wrapper. It accesses the inputs and normalizes them using `rin`, but it does NOT clone them first. As a result, the same target is normalized again and again on each iteration of `.predict`, because the tensor is passed by reference.

To Reproduce: You can use the following code to reproduce it (darts == 0.27.1). I only overrode some methods to add console logging, so the logic that produces the outputs is essentially the same.
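Expected behavior: The forward pass should not modify its input. Every `.predict` step should normalize an unmodified copy of the target, so the points after the first one are as sensible as the first.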
System: darts == 0.27.1 (other system details were not provided).
Additional context: I'll soon create a pull request, so please review it :)