Closed: Yann-CV closed this 4 months ago
To resolve the issue where newly created modules are not compatible with the `to` method, ensure that all custom modules and components inherit from `torch.nn.Module`. This is what allows `to` to be applied to them correctly: `Module.to()` only moves parameters, buffers, and submodules that are registered on a `Module` subclass.
For example, here is how you can ensure that your custom components inherit from `torch.nn.Module`:
from torch.nn import Module
class ImageTargeter(Module):
# Ensure this class inherits from torch.nn.Module
pass
class FilteredConv2d(Module):
# Ensure this class inherits from torch.nn.Module
pass
class TrainableAddPool2D(Module):
# Ensure this class inherits from torch.nn.Module
pass
class ScaledTanh(Module):
# Ensure this class inherits from torch.nn.Module
pass
class EuclideanDistanceRBF(Module):
# Ensure this class inherits from torch.nn.Module
pass
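The underlying reason is that `Module.to()` walks the module tree recursively, and a submodule only enters that tree when it is assigned as an attribute of a `Module` subclass (directly or via containers such as `ModuleList`/`ModuleDict`). A minimal sketch of the difference, with hypothetical class names, assuming a CUDA device is available:

import torch
from torch.nn import Linear, Module, ModuleList


class BrokenBlock(Module):
    def __init__(self) -> None:
        super().__init__()
        # Plain Python list: the Linear layers are NOT registered as submodules,
        # so .to() (and .parameters()) will not see them.
        self.layers = [Linear(4, 4), Linear(4, 4)]


class WorkingBlock(Module):
    def __init__(self) -> None:
        super().__init__()
        # ModuleList: the Linear layers ARE registered, so .to() moves them.
        self.layers = ModuleList([Linear(4, 4), Linear(4, 4)])


if torch.cuda.is_available():
    broken = BrokenBlock().to("cuda")
    working = WorkingBlock().to("cuda")
    print(broken.layers[0].weight.device)   # cpu: the layer was left behind
    print(working.layers[0].weight.device)  # cuda:0: the layer was moved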
For instance, the `FilteredConv2d` class should look like this:
import torch
from torch import Tensor
from torch.nn import Conv2d, ModuleList, Module
from torch.nn.common_types import _size_2_t
class FilteredConv2d(Module):
def __init__(
self,
in_channel_indices: list[list[int]],
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: str | _size_2_t = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
super().__init__()
        # A ModuleList (rather than a plain Python list) registers the Conv2d
        # layers as submodules, so .to() and .parameters() see them.
        self.conv2ds = ModuleList(
[
Conv2d(
len(channel_list),
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
for channel_list in in_channel_indices
]
)
self.in_channel_indices = in_channel_indices
    def forward(self, x: Tensor) -> Tensor:
        # Apply each Conv2d to its own subset of the input channels, then
        # concatenate the results along the channel dimension.
        conv_result_list = []
        for conv2d, in_channels in zip(
            self.conv2ds, self.in_channel_indices, strict=True
        ):
            conv_result_list.append(conv2d(x[:, in_channels]))
        return torch.cat(conv_result_list, dim=-3)
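A quick way to confirm that `to()` reaches every layer in the `ModuleList` is to move the module and inspect its parameters. The channel indices and tensor shape below are made up for illustration:

import torch

# Hypothetical configuration: two filtered convolutions over different input channel subsets.
module = FilteredConv2d(
    in_channel_indices=[[0, 1], [1, 2]],
    out_channels=4,
    kernel_size=3,
    padding=1,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
module = module.to(device)

# Every parameter of every Conv2d in the ModuleList should now be on `device`.
assert all(p.device.type == device for p in module.parameters())

x = torch.randn(1, 3, 8, 8, device=device)
out = module(x)
print(out.shape)  # torch.Size([1, 8, 8, 8]): 2 convs * 4 out_channels concatenated on dim -3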
Similarly, ensure that other custom components, such as `Vae`, also inherit from `torch.nn.Module` and handle device placement correctly:
from torch import Tensor, device as torch_device, randn
from torch.nn import Module
from torch.nn.functional import binary_cross_entropy

# VaeEncoder and VaeDecoder are assumed to be defined elsewhere in the project.
class Vae(Module):
def __init__(
self,
image_width: int,
image_height: int,
hidden_size: int,
latent_size: int,
device: str | torch_device | int = "cpu",
) -> None:
super().__init__()
self.device = torch_device(device)
self.latent_size = latent_size
self.encoder = VaeEncoder(
image_width, image_height, hidden_size, latent_size
).to(self.device)
self.decoder = VaeDecoder(
image_width, image_height, hidden_size, latent_size
).to(self.device)
def forward(self, x: Tensor) -> Tensor:
encoded = self.encoder(x)
return self.decoder(encoded)
def loss(self, x: Tensor) -> Tensor:
reconstructed = self.forward(x)
vector_size = x.size(-1) * x.size(-2)
reconstruction_loss = (
binary_cross_entropy(
reconstructed,
x,
reduction="none",
)
.reshape(-1, vector_size)
.sum(dim=1)
)
mean, log_var = self.encoder.compute_mean_and_log_var(x)
kl_div = -0.5 * (1 + log_var - log_var.exp() - mean.pow(2)).sum(dim=1)
loss = kl_div.mean() + reconstruction_loss.mean()
return loss
def generate(self, n: int) -> Tensor:
return self.decoder(randn(n, self.decoder.latent_size, device=self.device))
def __str__(self) -> str:
return (
f"VAE_image_width_{self.encoder.image_width}_image_height_{self.encoder.image_height}"
f"_hidden_size_{self.encoder.hidden_size}_latent_size_{self.encoder.latent_size}"
)
By ensuring all custom components inherit from `torch.nn.Module`, the `to` method will work correctly and all layers will be moved to the specified device.
fixed by #24
For some newly created modules, the `to` method is not working as expected: sometimes a layer can end up on the wrong device.