pyxem / kikuchipy

Toolbox for analysis of electron backscatter diffraction (EBSD) patterns
https://kikuchipy.org
GNU General Public License v3.0
79 stars 30 forks source link

ValueError from Dask in EBSD.refine_orientation() in develop branch #594

Closed hakonanes closed 1 year ago

hakonanes commented 1 year ago

In the development version, I got the following error when calling EBSD.refine_orientation() on an EBSD signal of data shape (784, 1121, 120, 120) with a crystal map from dictionary indexing and a varying projection center:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/highlevelgraph.py:781, in HighLevelGraph.get_all_external_keys(self)
    780 try:
--> 781     return self._all_external_keys
    782 except AttributeError:

AttributeError: 'HighLevelGraph' object has no attribute '_all_external_keys'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[57], line 1
----> 1 xmap_ref_fe = s.refine_orientation(
      2     xmap=xmap_di_fe,
      3     master_pattern=mp_fe,
      4     rechunk=True,
      5     chunk_kwargs=dict(chunk_shape=1000),
      6     **ref_kw
      7 )

File ~/kode/kikuchipy/kikuchipy/signals/ebsd.py:2133, in EBSD.refine_orientation(self, xmap, detector, master_pattern, energy, navigation_mask, signal_mask, method, method_kwargs, trust_region, initial_step, rtol, maxeval, compute, rechunk, chunk_kwargs)
   2120 points_to_refine = self._check_refinement_parameters(
   2121     xmap=xmap,
   2122     detector=detector,
   (...)
   2125     signal_mask=signal_mask,
   2126 )
   2127 patterns, signal_mask = self._prepare_patterns_for_refinement(
   2128     points_to_refine=points_to_refine,
   2129     signal_mask=signal_mask,
   2130     rechunk=rechunk,
   2131     chunk_kwargs=chunk_kwargs,
   2132 )
-> 2133 return _refine_orientation(
   2134     xmap=xmap,
   2135     detector=detector,
   2136     master_pattern=master_pattern,
   2137     energy=energy,
   2138     patterns=patterns,
   2139     points_to_refine=points_to_refine,
   2140     signal_mask=signal_mask,
   2141     method=method,
   2142     method_kwargs=method_kwargs,
   2143     trust_region=trust_region,
   2144     initial_step=initial_step,
   2145     rtol=rtol,
   2146     maxeval=maxeval,
   2147     compute=compute,
   2148     navigation_mask=navigation_mask,
   2149 )

File ~/kode/kikuchipy/kikuchipy/indexing/_refinement/_refinement.py:384, in _refine_orientation(xmap, detector, master_pattern, energy, patterns, points_to_refine, signal_mask, trust_region, rtol, method, method_kwargs, initial_step, maxeval, compute, navigation_mask)
    381 print(msg)
    383 if compute:
--> 384     res = compute_refine_orientation_results(
    385         results=res,
    386         xmap=xmap,
    387         master_pattern=master_pattern,
    388         navigation_mask=navigation_mask,
    389     )
    391 return res

File ~/kode/kikuchipy/kikuchipy/indexing/_refinement/_refinement.py:93, in compute_refine_orientation_results(results, xmap, master_pattern, navigation_mask)
     91 time_start = time()
     92 with ProgressBar():
---> 93     computed_results = results.compute()
     94 total_time = time() - time_start
     95 patterns_per_second = nav_size_in_data / total_time

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/base.py:315, in DaskMethodsMixin.compute(self, **kwargs)
    291 def compute(self, **kwargs):
    292     """Compute this dask collection
    293 
    294     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    313     dask.base.compute
    314     """
--> 315     (result,) = compute(self, traverse=False, **kwargs)
    316     return result

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/base.py:594, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    586     return args
    588 schedule = get_scheduler(
    589     scheduler=scheduler,
    590     collections=collections,
    591     get=get,
    592 )
--> 594 dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
    595 keys, postcomputes = [], []
    596 for x in collections:

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/base.py:367, in collections_to_dsk(collections, optimize_graph, optimizations, **kwargs)
    365 for opt, val in groups.items():
    366     dsk, keys = _extract_graph_and_keys(val)
--> 367     dsk = opt(dsk, keys, **kwargs)
    369     for opt_inner in optimizations:
    370         dsk = opt_inner(dsk, keys, **kwargs)

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/array/optimization.py:50, in optimize(dsk, keys, fuse_keys, fast_functions, inline_functions_fast_functions, rename_fused_keys, **kwargs)
     48 dsk = optimize_blockwise(dsk, keys=keys)
     49 dsk = fuse_roots(dsk, keys=keys)
---> 50 dsk = dsk.cull(set(keys))
     52 # Perform low-level fusion unless the user has
     53 # specified False explicitly.
     54 if config.get("optimization.fuse.active") is False:

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/highlevelgraph.py:938, in HighLevelGraph.cull(self, keys)
    934 from dask.layers import Blockwise
    936 keys_set = set(flatten(keys))
--> 938 all_ext_keys = self.get_all_external_keys()
    939 ret_layers: dict = {}
    940 ret_key_deps: dict = {}

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/highlevelgraph.py:788, in HighLevelGraph.get_all_external_keys(self)
    783 keys: set = set()
    784 for layer in self.layers.values():
    785     # Note: don't use `keys |= ...`, because the RHS is a
    786     # collections.abc.Set rather than a real set, and this will
    787     # cause a whole new set to be constructed.
--> 788     keys.update(layer.get_output_keys())
    789 self._all_external_keys = keys
    790 return keys

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/blockwise.py:543, in Blockwise.get_output_keys(self)
    537     return {(self.output, *p) for p in self.output_blocks}
    539 # Return all possible output keys (no culling)
    540 return {
    541     (self.output, *p)
    542     for p in itertools.product(
--> 543         *[range(self.dims[i]) for i in self.output_indices]
    544     )
    545 }

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/blockwise.py:543, in <listcomp>(.0)
    537     return {(self.output, *p) for p in self.output_blocks}
    539 # Return all possible output keys (no culling)
    540 return {
    541     (self.output, *p)
    542     for p in itertools.product(
--> 543         *[range(self.dims[i]) for i in self.output_indices]
    544     )
    545 }

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/blockwise.py:503, in Blockwise.dims(self)
    499 """Returns a dictionary mapping between each index specified in
    500 `self.indices` and the number of output blocks for that indice.
    501 """
    502 if not hasattr(self, "_dims"):
--> 503     self._dims = _make_dims(self.indices, self.numblocks, self.new_axes)
    504 return self._dims

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/blockwise.py:1711, in _make_dims(indices, numblocks, new_axes)
   1707 def _make_dims(indices, numblocks, new_axes):
   1708     """Returns a dictionary mapping between each index specified in
   1709     `indices` and the number of output blocks for that indice.
   1710     """
-> 1711     dims = broadcast_dimensions(indices, numblocks)
   1712     for k, v in new_axes.items():
   1713         dims[k] = len(v) if isinstance(v, tuple) else 1

File ~/miniconda3/envs/kp-dev/lib/python3.10/site-packages/dask/blockwise.py:1702, in broadcast_dimensions(argpairs, numblocks, sentinels, consolidate)
   1699     return toolz.valmap(consolidate, g2)
   1701 if g2 and not set(map(len, g2.values())) == {1}:
-> 1702     raise ValueError("Shapes do not align %s" % g)
   1704 return toolz.valmap(toolz.first, g2)

ValueError: Shapes do not align {'.1': {152, 98}, '.0': {1}}

The signal chunk shape was (98, 66, 120, 120), so I assume the 98 in the error message comes from there. I don't know where the 152 comes from, though.

I don't know whether this is also an issue in the current release.

hakonanes commented 1 year ago

The issue was a differing number of blocks (`numblocks`) among the Dask arrays passed to `dask.array.map_blocks()`. The patterns array was split into 152 blocks, while the remaining arrays (rotations, PCs, etc.) were split into 98. Fixed in #597.