materialsproject / pymatgen

Python Materials Genomics (pymatgen) is a robust materials analysis code that defines classes for structures and molecules with support for many electronic structure codes. It powers the Materials Project.
https://pymatgen.org

[BUG] Incorrect usage of `np.all() > tol` for element-wise value comparison #4166

Closed DanielYang59 closed 4 days ago

DanielYang59 commented 1 week ago

The following method: https://github.com/materialsproject/pymatgen/blob/bd9fba9ec62437b5b62fbd0b2c2c723216cc5a2c/src/pymatgen/io/lobster/outputs.py#L1694-L1737

uses a pattern like np.array().all() > tolerance, apparently intending an element-wise value comparison. However, this misuses all(), which the NumPy docs describe as:

"Test whether all array elements along a given axis evaluate to True."

import numpy as np

arr = np.array([1, 2, 3])

print(arr.all())  # True

print(np.array([False, 1, 2]).all())  # False

Should be something like:

- if abs(band2 - 1.0).all() > limit_deviation:
# However this logic might need double check as it only evaluates to True 
# when ALL deviations are greater than threshold
+ if np.all(np.abs(band2 - 1) > limit_deviation):
# Perhaps it should be the following?
+ if np.any(np.abs(band2 - 1) > limit_deviation):

# The same applies to the other branch
- elif band2.all() > limit_deviation:
+ elif np.all(band2 > limit_deviation):
# The following seems more sensible?
+ elif np.any(band2 > limit_deviation):
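To make the difference between the two candidates concrete, here is a minimal sketch. The array and threshold are made-up values for illustration, not real LOBSTER output:

```python
import numpy as np

limit_deviation = 0.1  # illustrative threshold
band2 = np.array([1.0, 1.5, 1.02])  # made-up overlaps; only the middle one is far from 1

# Compare element-wise first, then reduce the boolean array.
exceeds = np.abs(band2 - 1) > limit_deviation
all_exceed = bool(np.all(exceeds))   # True only if EVERY deviation is above the limit
any_exceeds = bool(np.any(exceeds))  # True if AT LEAST ONE deviation is above the limit

print(exceeds.tolist(), all_exceed, any_exceeds)
# [False, True, False] False True
```

With np.all the single bad value goes unnoticed, which is why np.any looks like the better fit for a quality check.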

I might go ahead and change this if I didn't misunderstand anything? @JaGeo

JaGeo commented 1 week ago

@naik-aakash Could you kindly check this (after Thursday ;))? Thanks a lot.

naik-aakash commented 1 week ago

Hi @DanielYang59 , thanks for pointing this out. Indeed there seems to have been an error with comparison logic here. Using np.any should be preferable here as you suggested I think.

I assume you have already started working on this issue; if not, I am also happy to raise a PR for this. Let me know.

DanielYang59 commented 1 week ago

Hi @naik-aakash greetings thanks for looking into this :)

Using np.any should be preferable here as you suggested I think.

I might have caused confusion here, sorry. The main issue is not the choice of any() over all(), but the order of the comparison and the all/any call. We should do the comparison before the all/any reduction:

Test whether all array elements along a given axis evaluate to True.

The current implementation does it the other way round: calling array.all() first reduces the whole array to a single bool, so we're basically comparing that bool (0/1) with limit_deviation.
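A minimal sketch of that failure mode, again with made-up values rather than real data. It suggests the reduce-then-compare pattern can go wrong in both directions:

```python
import numpy as np

limit_deviation = 0.1  # illustrative threshold

# Case A: every value is within tolerance, but no deviation is exactly zero.
within = np.array([1.02, 0.99, 1.03])
# .all() sees only non-zero (truthy) deviations -> True, and True > 0.1 is True.
buggy_a = bool(abs(within - 1.0).all() > limit_deviation)
fixed_a = bool(np.any(np.abs(within - 1.0) > limit_deviation))

# Case B: one value is far off, but another deviation is exactly zero.
outside = np.array([1.0, 1.5])
# .all() sees a zero (falsy) deviation -> False, and False > 0.1 is False.
buggy_b = bool(abs(outside - 1.0).all() > limit_deviation)
fixed_b = bool(np.any(np.abs(outside - 1.0) > limit_deviation))

print(buggy_a, fixed_a)  # True False  (false alarm in the buggy form)
print(buggy_b, fixed_b)  # False True  (missed deviation in the buggy form)
```

In other words, the buggy form only reports whether all deviations are non-zero; the actual size of limit_deviation barely matters as long as it lies between 0 and 1.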


I assume you have already started working on this issue; if not, I am also happy to raise a PR for this. Let me know.

I might need your help to fix this, and it's apparently not covered by the unit tests. Perhaps you could extend the unit tests on top of fixing the implementation? Thank you!

QuantumChemist commented 6 days ago

Hey 😃 ,

I have an additional suggestion to improve this function further. Instead of looping through each value like that, https://github.com/materialsproject/pymatgen/blob/bd9fba9ec62437b5b62fbd0b2c2c723216cc5a2c/src/pymatgen/io/lobster/outputs.py#L1713-L1735 maybe we could aim for a more compact comparison:

import numpy as np
from pymatgen.electronic_structure.core import Spin

def validate_band_overlaps(matrices, num_occ_bands, limit_deviation):
    """
    Validate diagonal and off-diagonal elements for band overlaps.
    """
    for matrix in matrices:
        matrix = np.array(matrix)
        # Limit matrix to occupied bands
        sub_matrix = matrix[:num_occ_bands, :num_occ_bands]

        # Check diagonal elements
        if np.any(np.abs(np.diag(sub_matrix) - 1.0) > limit_deviation):
            return False

        # Check off-diagonal elements
        if np.any(np.abs(sub_matrix - np.diag(np.diag(sub_matrix))) > limit_deviation):
            return False

    return True

def check_band_overlaps(band_overlaps_dict, number_occ_bands_spin_up, limit_deviation, spin_polarized=False, number_occ_bands_spin_down=None):
    """
    Check band overlaps for spin-up and optionally spin-down configurations.
    """
    # Validate spin-up matrices
    if not validate_band_overlaps(band_overlaps_dict[Spin.up]["matrices"], number_occ_bands_spin_up, limit_deviation):
        return False

    # Validate spin-down matrices if spin-polarized
    if spin_polarized:
        if number_occ_bands_spin_down is None:
            raise ValueError("number_occ_bands_spin_down must be specified.")
        if not validate_band_overlaps(band_overlaps_dict[Spin.down]["matrices"], number_occ_bands_spin_down, limit_deviation):
            return False

    return True

(Disclaimer: this is ChatGPT-generated but looks about like the code structure I would have proposed, maybe one can condense it a little bit).

What do you think about that, @naik-aakash ?

DanielYang59 commented 5 days ago

@QuantumChemist Brilliant idea, I didn't look closely into the code itself at the time of opening this issue, but I fully support your idea to reduce code repetition and nest level.

I'm not quite familiar with the expected type and shape of the matrix from band_overlaps_dict[Spin.up]["matrices"], but it looks like it's a 2D square array of floats? If that's the case, perhaps we could compare its top-left num_occ_bands x num_occ_bands leading principal sub-array with an identity array? Something like:

if spin_polarized and number_occ_bands_spin_down is None:
    raise ValueError("number_occ_bands_spin_down has to be specified")

for spin in (Spin.up, Spin.down) if spin_polarized else (Spin.up,):
    num_occ_bands = number_occ_bands_spin_up if spin == Spin.up else number_occ_bands_spin_down

    for array in self.band_overlaps_dict[spin]["matrices"]:
        # Looks like it is already np.array?
        sub_array = np.asarray(array)[:num_occ_bands, :num_occ_bands]

        if not np.allclose(sub_array, np.identity(num_occ_bands), atol=limit_deviation, rtol=0):
            return False

return True
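As a sanity check, here is a small sketch suggesting that the np.allclose form agrees with separate diagonal and off-diagonal checks on finite inputs. The function names (two_step_check, identity_check) and the test arrays are hypothetical and only for illustration; they are not part of pymatgen:

```python
import numpy as np


def two_step_check(array, num_occ_bands, limit_deviation):
    """Check diagonal (close to 1) and off-diagonal (close to 0) separately."""
    sub = np.asarray(array, dtype=float)[:num_occ_bands, :num_occ_bands]
    diag = np.diag(sub)
    off_diag = sub - np.diag(diag)  # zero out the diagonal
    diag_bad = np.any(np.abs(diag - 1.0) > limit_deviation)
    off_diag_bad = np.any(np.abs(off_diag) > limit_deviation)
    return not (diag_bad or off_diag_bad)


def identity_check(array, num_occ_bands, limit_deviation):
    """Compare the occupied-band block with an identity array in one call."""
    sub = np.asarray(array, dtype=float)[:num_occ_bands, :num_occ_bands]
    # rtol=0 makes this a pure absolute-tolerance comparison.
    return bool(np.allclose(sub, np.identity(num_occ_bands), atol=limit_deviation, rtol=0))


# Made-up 3x3 arrays; only the top-left 2x2 block is checked.
cases = {
    "good": [[1.0, 0.01, 0.5], [0.01, 0.99, 0.5], [0.5, 0.5, 0.2]],
    "bad_diagonal": [[0.8, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    "bad_off_diagonal": [[1.0, 0.3, 0.0], [0.3, 1.0, 0.0], [0.0, 0.0, 1.0]],
}
results = {
    name: (two_step_check(arr, 2, 0.1), identity_check(arr, 2, 0.1))
    for name, arr in cases.items()
}
print(results)
```

One difference to be aware of: if an array contains NaN, np.allclose returns False (with the default equal_nan=False), whereas the element-wise "> limit_deviation" checks would let the NaN pass silently.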

I realized I might be misusing the word "matrix" as we're not doing linear algebra operations; I believe "array" might be a better term. I'm mentioning this because NumPy discourages the use of np.matrix even for actual linear algebra:

It is no longer recommended to use this class, even for linear algebra. Instead use regular arrays. The class may be removed in the future.

naik-aakash commented 5 days ago

Thank you so much, @QuantumChemist and @DanielYang59, for the suggestions to reduce the code complexity. I will implement the suggestions and raise a PR soon.

And @DanielYang59, you are right: it is not using np.matrix but is already a NumPy array. The term "matrix" is used in the Bandoverlaps class just to stay consistent with the actual output files from LOBSTER.