MLResearchAtOSRAM / tmm_fast

tmm_fast is a lightweight package to speed up optical planar multilayer thin-film device computation. Developed by Alexander Luce (@Nerrror) in cooperation with Heribert Wankerl (@HarryTheBird).
MIT License

BUG: Unexpected squeezing of Theta and lambda_vacuum inputs in coh_vec_tmm_disp_mstack #7

Closed: kadykov closed this issue 1 year ago

kadykov commented 1 year ago

I tried to compare the output of the `coh_vec_tmm_disp_mstack` function with the original tmm package. To do this, I reduced the `Theta` and `lambda_vacuum` vectors to a single element each, as follows:

```python
import numpy as np
from tmm_fast.vectorized_tmm_dispersive_multistack import coh_vec_tmm_disp_mstack as tmm

# wl = np.asarray([400., 500.,]) * 1e-9 # This works well
wl = np.asarray([400.,]) * 1e-9 # AssertionError: N and T are not of same shape, as they are of dimensions 3 and 2
# theta = np.asarray([0., 45.,]) # This also works well
theta = np.asarray([0.,]) # IndexError: tuple index out of range
mode = 'T'
num_layers = 4
num_stacks = 128

refractive_index = np.ones([num_stacks, num_layers, wl.shape[0]])
thickness = np.ones([num_stacks, num_layers]) * 100

tmm(
    pol="s",
    N=refractive_index,
    T=thickness,
    Theta=theta,
    lambda_vacuum=wl,
)
```

However, with a single wavelength (monochromatic light), I get the following error:

```
AssertionError: N and T are not of same shape, as they are of dimensions 3 and 2
```

And with a single angle of incidence:

```
IndexError: tuple index out of range
```
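The sketch below uses plain PyTorch, independent of tmm_fast, to show what `squeeze()` does to inputs shaped like the ones above. The wavelength axis of `N` disappears, and the one-element `Theta` becomes a 0-d tensor whose shape is empty. Indexing that empty shape raises the same `IndexError`. That tmm_fast later reads something like `Theta.shape[0]` is an assumption inferred from the error message. I have not checked it against the library source.

```python
import numpy as np
import torch

# Inputs shaped as in the report: 128 stacks, 4 layers, one wavelength, one angle.
N = torch.from_numpy(np.ones([128, 4, 1]))
theta = torch.from_numpy(np.asarray([0.0]))

# squeeze() with no arguments drops every dimension of size 1.
N_sq = N.squeeze()
theta_sq = theta.squeeze()
print(N.shape, '->', N_sq.shape)          # wavelength axis is gone
print(theta.shape, '->', theta_sq.shape)  # 1-D tensor becomes 0-d

# A 0-d tensor has an empty shape, so taking its first entry fails.
try:
    theta_sq.shape[0]
except IndexError as err:
    print('IndexError:', err)
```

The same collapse would also hit the stack axis if `num_stacks` were 1, so the problem is not limited to wavelengths and angles.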

I think both problems could be fixed by removing the unnecessary `squeeze()` call from the `converter` function in the `vectorized_tmm_dispersive_multistack` module:

```python
import numpy as np
import torch

def converter(data, device):
    if type(data) is not torch.Tensor:
        if type(data) is np.ndarray:
            data = torch.from_numpy(data.copy())
        else:
            raise ValueError('At least one of the inputs (i.e. N, Theta, ...) is not of type numpy.array or torch.Tensor!')
    data = data.type(torch.cfloat).to(device)
    # squeeze() drops every size-1 dimension, e.g. a single wavelength or angle
    return data.squeeze()
```
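A minimal sketch of the proposed change follows. It keeps the type checks and error message of `converter` unchanged and only drops the final `squeeze()`, so singleton wavelength, angle, or stack axes keep their size-1 dimension. The sketch assumes that no other part of the module relies on `converter` removing size-1 dimensions. That assumption would need checking against the rest of `vectorized_tmm_dispersive_multistack`.

```python
import numpy as np
import torch

def converter(data, device):
    """Convert a NumPy array or tensor to a complex64 tensor on `device`.

    Unlike the current version, size-1 dimensions are preserved (no squeeze).
    """
    if type(data) is not torch.Tensor:
        if type(data) is np.ndarray:
            data = torch.from_numpy(data.copy())
        else:
            raise ValueError('At least one of the inputs (i.e. N, Theta, ...) is not of type numpy.array or torch.Tensor!')
    # torch.cfloat is an alias for torch.complex64
    return data.type(torch.cfloat).to(device)

theta = converter(np.asarray([0.0]), 'cpu')
N = converter(np.ones([128, 4, 1]), 'cpu')
print(theta.shape, N.shape)  # torch.Size([1]) torch.Size([128, 4, 1])
```

With this version, a single angle stays a 1-D tensor of length 1 and `N` keeps its wavelength axis. The shapes then match the multi-wavelength, multi-angle case that already works.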