tensorly / viz

Easy visualization and evaluation of matrix and tensor factorization models
https://tensorly.org/viz/
MIT License

Mode information in component comparison plot #3

Closed: caglayantuna closed this issue 2 years ago

caglayantuna commented 2 years ago

Would it be possible to add an option to change the x_label in the component_comparison_plot function? For example, TensorLy has some datasets (and will probably have more soon) that come with dims information.

I am planning to add some benchmark functions to TensorLy that use the tlviz package, and it would be nice to use this information in the figures.
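Until such an option exists, one manual workaround is to set the labels on the returned axes yourself. Below is a minimal sketch, assuming the plotting function returns a 2D array of matplotlib axes with one column per tensor mode; `set_mode_xlabels` and the mode names are hypothetical placeholders, not part of tlviz.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def set_mode_xlabels(axes, mode_names):
    """Hypothetical helper: put one x-label per mode on the bottom row of axes.

    Assumes `axes` is a (rows, modes) grid, i.e. one column per tensor mode.
    """
    axes = np.atleast_2d(axes)
    if axes.shape[1] != len(mode_names):
        raise ValueError("Need exactly one name per column of axes")
    # Only the bottom row gets labels, so the grid does not repeat them
    for ax, name in zip(axes[-1], mode_names):
        ax.set_xlabel(name)
    return axes


# Stand-in for the (fig, axes) pair returned by a plotting function
fig, axes = plt.subplots(2, 3)
set_mode_xlabels(axes, ["Subject", "Time", "Channel"])  # placeholder mode names
```

Labelling only the bottom row keeps a shared-column layout readable; if each row should carry its own labels, loop over every row instead.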

MarieRoald commented 2 years ago

While writing a response, I found a bug that should hopefully be fixed now. Thanks for helping spot it!

If I understand you correctly, you want a reasonably labelled x-axis. TLViz supports this natively if you store the data as an xarray DataArray (due to a bug, this was not working earlier, but it is fixed now). I haven't looked into the data, and I don't know whether the masking or the number of components makes sense, so apologies if these components look weird; I just wanted to make a quick example. For the TensorLy datasets, I believe you could do something like this:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tlviz
import xarray as xr

from tensorly.datasets import IL2data  # newer TensorLy releases name this loader load_IL2data
from tensorly.decomposition import parafac

sns.set()

# Loading data and storing it in an xarray DataArray
il2 = IL2data()
dataset = xr.DataArray(
    il2.tensor,
    coords={dim: coord for dim, coord in zip(il2.dims, il2.ticks)},
    dims=il2.dims
)
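# Remember which entries are missing, then fill them with zeros so parafac gets finite values
mask = ~np.isnan(dataset.data)
dataset = dataset.fillna(0)

# Fitting and postprocessing two PARAFAC models, without and with a mask
# Without a mask, the zero-filled entries are treated as real zeros
cp = parafac(dataset.data, rank=3)
cp_postprocessed = tlviz.postprocessing.postprocess(cp, dataset)

# With a mask, entries where mask is False are treated as missing
cp_masked = parafac(dataset.data, rank=3, mask=mask)
cp_masked_postprocessed = tlviz.postprocessing.postprocess(cp_masked, dataset)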

# Plotting the components (one row per model, one column per mode)
fig, axes = tlviz.visualisation.component_comparison_plot(
    {'With mask': cp_masked_postprocessed, 'Without mask': cp_postprocessed}
)

# Rotating the x-ticks in the bottom row for the two modes with long labels
axes[1, 0].xaxis.set_tick_params(rotation=90)
axes[1, 3].xaxis.set_tick_params(rotation=90)

plt.show()
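As far as I understand tlviz, postprocessing against an xarray dataset labels the decomposition by turning each factor matrix into a pandas DataFrame whose index holds that mode's coordinates and whose index name holds the dimension name; those labels are what end up on the x-axis. The sketch below mimics that labelling step with plain pandas to illustrate the idea. `label_factor_matrices`, the dims and the ticks are hypothetical, and this is not tlviz's implementation.

```python
import numpy as np
import pandas as pd


def label_factor_matrices(factor_matrices, dims, ticks):
    """Hypothetical helper: wrap each factor matrix in a labelled DataFrame.

    The index holds the mode's tick labels and the index name holds the
    mode's dimension name (mimicking, not reproducing, tlviz's labelling).
    """
    labelled = []
    for factor, dim, tick in zip(factor_matrices, dims, ticks):
        if factor.shape[0] != len(tick):
            raise ValueError(f"Mode {dim!r}: {factor.shape[0]} rows but {len(tick)} ticks")
        index = pd.Index(tick, name=dim)
        labelled.append(pd.DataFrame(factor, index=index))
    return labelled


rng = np.random.default_rng(0)
dims = ["Sample", "Time"]                      # placeholder dimension names
ticks = [["s1", "s2", "s3"], [0, 10, 20, 30]]  # placeholder coordinates
rank = 2
factors = [rng.random((len(t), rank)) for t in ticks]

labelled = label_factor_matrices(factors, dims, ticks)
print(labelled[1].index.name)  # Time
```

This is also why the dims and ticks are passed as `dims` and `coords` when building the DataArray: they are the only place the labels come from.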

caglayantuna commented 2 years ago

Thanks @MarieRoald for your answer, it helps a lot. This is exactly what I wanted for the dims, and you also gave me the idea to use the ticks.