TRIQS / triqs

a Toolbox for Research on Interacting Quantum Systems
https://triqs.github.io
GNU General Public License v3.0

BUG? MeshBrillouinZone mesh points not implementing addition correctly in Python #557

Closed: HugoStrand closed this issue 5 years ago

HugoStrand commented 6 years ago

Dear Nils,

I am trying to add mesh_points from MeshBrillouinZone in Python and get a rather surprising result.

I understand that the sum of two mesh points k + q is not necessarily on the mesh, but I expected the result to be a 3-component vector. Instead I get a 3-component, rank-1 np.ndarray of dtype object whose elements are themselves 3-component rank-1 np.ndarrays. See a minimal example below.

Can this be fixed in an easy way?

Best, Hugo

(The workaround k.value + q.value works, but it is not very intuitive.)

import numpy as np
from pytriqs.gf import MeshBrillouinZone
from pytriqs.lattice import BrillouinZone, BravaisLattice

units = np.eye(3)  # simple cubic lattice with unit lattice vectors
periodization_matrix = 6 * np.eye(3, dtype=np.int32)  # 6 x 6 x 6 k-mesh

bl = BravaisLattice(units)
bz = BrillouinZone(bl)

bzmesh = MeshBrillouinZone(bz, periodization_matrix)

for k in bzmesh:
    for q in bzmesh:
        print 'k =', k
        print 'q =', q

        print 'k + q =\n', k + q
        print type(k + q)
        print (k + q).shape
        print (k + q).__repr__()
        exit()  # stop after inspecting the first (k, q) pair

This produces:

k = mesh_point(linear_index = 0, value = [0. 0. 0.])
q = mesh_point(linear_index = 0, value = [0. 0. 0.])
k + q =
[array([0., 0., 0.]) array([0., 0., 0.]) array([0., 0., 0.])]
<type 'numpy.ndarray'>
(3,)
array([array([0., 0., 0.]), array([0., 0., 0.]), array([0., 0., 0.])],
      dtype=object)
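
The nested result can be reproduced without TRIQS. The sketch below (Python 3) uses a hypothetical stand-in class, OldMeshPoint, that copies only the pre-patch __add__/__radd__ pair from the diff further down; it is not the real pytriqs.gf.mesh_point.MeshPoint. When k + q evaluates self.value + x with x a mesh point, NumPy treats x as an opaque object and adds it to each float element of k.value separately. Each of those scalar additions falls back to x.__radd__, which returns a whole array, so the result is an object array of arrays.

```python
import numpy as np


class OldMeshPoint:
    """Hypothetical stand-in reproducing the pre-patch MeshPoint arithmetic."""

    def __init__(self, linear_index, value):
        self.linear_index, self.value = linear_index, value

    def __add__(self, x):
        # Pre-patch behaviour: x is used as-is, even if it is another mesh point
        return self.value + x

    def __radd__(self, x):
        return x + self.value


k = OldMeshPoint(0, np.array([1.0, 2.0, 3.0]))
q = OldMeshPoint(1, np.array([10.0, 20.0, 30.0]))

# NumPy wraps q as a 0-d object array, so k.value + q is evaluated element by
# element: each float k.value[i] + q falls back to q.__radd__(k.value[i]),
# which returns the whole array k.value[i] + q.value.
result = k + q
print(result.dtype, result.shape)  # object (3,)
print(result[0])                   # [11. 21. 31.]
```

With non-zero values the nesting is easier to see: result[i] is k.value[i] + q.value, not the component-wise sum.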

Wentzell commented 6 years ago

The following patch solves this issue. @parcollet, please review.

diff --git a/pytriqs/gf/mesh_point.py b/pytriqs/gf/mesh_point.py
index 49cbf6f5..1c5bd8c2 100644
--- a/pytriqs/gf/mesh_point.py
+++ b/pytriqs/gf/mesh_point.py
@@ -34,7 +34,10 @@ class MeshPoint :
         self.linear_index, self.value = linear_index, value

     def __add__(self, x):
-        return self.value + x
+        if(isinstance(x, MeshPoint)):
+            return self.value + x.value
+        else:
+            return self.value + x

     def __radd__(self, x):
         return x + self.value
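
For comparison, a minimal sketch of the patched dispatch, again on a hypothetical stand-in class (PatchedMeshPoint) rather than the real TRIQS class. Once __add__ unwraps x.value when the right operand is a mesh point, k + q is an ordinary NumPy addition of two float arrays and returns a plain vector of shape (3,). Adding a plain array still goes through the else branch.

```python
import numpy as np


class PatchedMeshPoint:
    """Hypothetical stand-in mirroring the patched MeshPoint.__add__."""

    def __init__(self, linear_index, value):
        self.linear_index, self.value = linear_index, value

    def __add__(self, x):
        # Unwrap the other mesh point so NumPy only ever sees two float arrays
        if isinstance(x, PatchedMeshPoint):
            return self.value + x.value
        return self.value + x

    def __radd__(self, x):
        return x + self.value


k = PatchedMeshPoint(0, np.array([1.0, 2.0, 3.0]))
q = PatchedMeshPoint(1, np.array([10.0, 20.0, 30.0]))

kq = k + q
print(kq.dtype, kq.shape)  # float64 (3,)
print(kq)                  # [11. 22. 33.]

# Adding a plain vector still works through the non-mesh-point branch
print(k + np.array([0.5, 0.5, 0.5]))  # [1.5 2.5 3.5]
```

The sum is returned as a bare array rather than wrapped back into a mesh point, which fits the point made in the issue that k + q need not lie on the mesh.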

Wentzell commented 5 years ago

Fixed by #557