giotto-ai / giotto-tda

A high-performance topological machine learning toolbox in Python
https://giotto-ai.github.io/gtda-docs

Modify `PlotterMixin.transform_plot` to give a dictionary to the call to `plot` #484

Closed ulupo closed 4 years ago

ulupo commented 4 years ago

Reference issues/PRs Fixes a bug introduced in #453.

Description In #453, automatic titles were introduced for several (but not all) plot methods. However, this leads to incorrectly reported sample indices when plot is called from transform_plot, because of the latter's internal logic: transform_plot transforms the single-sample slice X[sample:sample+1] and then calls plot with sample=0, so the original index is lost by the time the title is built. This PR keeps the titles but fixes them by modifying the transform_plot method of PlotterMixin so that plot is called not on an array of length 1 but on a dictionary with a single key:

        Xt = self.transform(X[sample:sample+1])
        self.plot(Xt, sample=0, **plot_params).show()

        return Xt

becomes

        Xt = self.transform(X[sample:sample+1])
        self.plot({sample: Xt[0]}, sample=sample, **plot_params).show()

        return Xt
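
To make the difference concrete, here is a minimal, self-contained sketch. `ToyPlotter` is a hypothetical stand-in, not a giotto-tda class: its `plot` returns a title string instead of a figure, on the assumption that real `plot` methods index their input as `Xt[sample]` and build the title from `sample`.

```python
import numpy as np


class ToyPlotter:
    """Hypothetical stand-in for a transformer that uses PlotterMixin."""

    def transform(self, X):
        # Toy transform: double every entry.
        return 2 * np.asarray(X)

    def plot(self, Xt, sample=0):
        # Assumed behaviour: index the input by `sample` and put `sample`
        # in the title. Works for arrays and for dicts alike.
        values = Xt[sample]
        return f"Sample {sample}: {values.tolist()}"

    def transform_plot_old(self, X, sample=0):
        Xt = self.transform(X[sample:sample + 1])
        # Length-1 array: the only valid index is 0, so the title says 0.
        return Xt, self.plot(Xt, sample=0)

    def transform_plot_new(self, X, sample=0):
        Xt = self.transform(X[sample:sample + 1])
        # Single-key dict: `Xt[sample]` resolves and the title keeps the
        # original index.
        return Xt, self.plot({sample: Xt[0]}, sample=sample)


X = np.arange(6).reshape(3, 2)
plotter = ToyPlotter()
_, old_title = plotter.transform_plot_old(X, sample=2)
_, new_title = plotter.transform_plot_new(X, sample=2)
print(old_title)  # Sample 0: [8, 10]
print(new_title)  # Sample 2: [8, 10]
```

Both variants return the same `Xt`; only the object handed to `plot` and the reported index differ.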

ulupo commented 4 years ago

Another "solution" would be to make the outputs of transform in the code for transform_plot into dictionaries with the single key sample (which could be non-zero). Not sure if more or less elegant! It would at least remove the need for sample_orig completely...

wreise commented 4 years ago

> Another "solution" would be to make the outputs of transform in the code for transform_plot into dictionaries with the single key sample (which could be non-zero). Not sure if more or less elegant! It would at least remove the need for sample_orig completely...

I don't understand why you wouldn't need sample_orig. Wouldn't you need to pass the key to the VectorisationTransformer.plot method anyway?

ulupo commented 4 years ago

@wreise to make it clearer, this is the proposed fix: the code for transform_plot would be

        Xt = self.transform(X[sample:sample+1])
        self.plot({sample: Xt[0]}, sample=sample, **plot_params).show()

        return Xt

Then you have to change nothing else, I claim. At least in the majority of cases...

wreise commented 4 years ago

> @wreise to make it clearer, this is the proposed fix: the code for transform_plot would be
>
>         Xt = self.transform(X[sample:sample+1])
>         self.plot({sample: Xt[0]}, sample=sample, **plot_params).show()
>
>         return Xt

IMO, this looks better, even though we feed a dict instead of an ndarray. The "problem" appears only when transform on a single sample is combined with plot, so it should be fixed there.

> Then you have to change nothing else, I claim. At least in the majority of cases...

Maybe I can test that?

ulupo commented 4 years ago

> Maybe I can test that?

Oh please!! :) Feel free to take over this PR. I'm happy to proceed with the dict route.
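
A rough sketch of what such a test might look like, under stated assumptions: `RecordingPlotter` and `check_transform_plot_keeps_sample_index` are hypothetical names, the class only mimics the proposed `transform_plot` body, and its `plot` records its arguments instead of drawing anything. A real test would exercise the actual transformers.

```python
import numpy as np


class RecordingPlotter:
    """Hypothetical transformer that records what `plot` receives."""

    def __init__(self):
        self.calls = []

    def transform(self, X):
        # Toy transform: shift every entry by one.
        return np.asarray(X) + 1

    def plot(self, Xt, sample=0, **plot_params):
        self.calls.append((Xt, sample))
        return self  # stand-in for a figure object

    def show(self):
        # No-op so that `.plot(...).show()` can be chained as in the fix.
        pass

    def transform_plot(self, X, sample=0, **plot_params):
        # Same body as the proposed fix.
        Xt = self.transform(X[sample:sample + 1])
        self.plot({sample: Xt[0]}, sample=sample, **plot_params).show()
        return Xt


def check_transform_plot_keeps_sample_index(sample):
    X = np.arange(12).reshape(4, 3)
    plotter = RecordingPlotter()
    Xt = plotter.transform_plot(X, sample=sample)
    received, received_sample = plotter.calls[-1]
    assert Xt.shape == (1, 3)            # still a length-1 array
    assert received_sample == sample     # original index reaches plot
    assert list(received) == [sample]    # dict has exactly that one key
    assert np.array_equal(received[sample], Xt[0])
    return True


results = [check_transform_plot_keeps_sample_index(s) for s in range(4)]
print(results)
```

Looping over every valid `sample`, including non-zero ones, is the point: the original bug would only show up for indices other than 0.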