wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

Concatenation error when calling `merge_contexts` with zero-size context set #16

Closed. tom-andersson closed this issue 1 year ago

tom-andersson commented 1 year ago

Calling `merge_contexts` to concatenate two context sets fails when one of them contains zero observations (i.e. its data/observation dimension has length 0).

MWE: https://colab.research.google.com/drive/1P_9LpXGgX21E72p2nvTnLvmrOzx_Zyjm#scrollTo=hSuekNzNLRQB

import neuralprocesses.tensorflow as nps
import tensorflow as tf
import lab.tensorflow as B
import time

# This raises an error from concatenation
nps.merge_contexts(  # 1st context set has 0 observations, 2nd has 63 observations
    (B.randn(tf.float32, 1, 2, 0), B.randn(tf.float32, 1, 1, 0)),
    (B.randn(tf.float32, 1, 2, 63), B.randn(tf.float32, 1, 1, 63)),
)

Produces:

/usr/local/lib/python3.10/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    390             method, return_type = self.resolve_method(args, types)
    391 
--> 392         return _convert(method(*args, **kw_args), return_type)
    393 
    394     def invoke(self, *types):

/usr/local/lib/python3.10/dist-packages/neuralprocesses/mask.py in merge_contexts(*contexts, **kw_args)
    114     xcs, ycs = zip(*contexts)
    115     xcs = tuple((xc,) for xc in xcs)  # Pack inputs.
--> 116     xc, yc = merge_contexts(*zip(xcs, ycs), **kw_args)
    117     return xc[0], yc  # Unpack inputs.

/usr/local/lib/python3.10/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    390             method, return_type = self.resolve_method(args, types)
    391 
--> 392         return _convert(method(*args, **kw_args), return_type)
    393 
    394     def invoke(self, *types):

/usr/local/lib/python3.10/dist-packages/neuralprocesses/mask.py in merge_contexts(multiple, *contexts)
    105 
    106     return (
--> 107         tuple(B.concat(*xcsi, axis=0) for xcsi in zip(*xcs)),
    108         Masked(B.concat(*ycs, axis=0), B.concat(*masks, axis=0)),
    109     )

/usr/local/lib/python3.10/dist-packages/neuralprocesses/mask.py in <genexpr>(.0)
    105 
    106     return (
--> 107         tuple(B.concat(*xcsi, axis=0) for xcsi in zip(*xcs)),
    108         Masked(B.concat(*ycs, axis=0), B.concat(*masks, axis=0)),
    109     )

/usr/local/lib/python3.10/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    390             method, return_type = self.resolve_method(args, types)
    391 
--> 392         return _convert(method(*args, **kw_args), return_type)
    393 
    394     def invoke(self, *types):

/usr/local/lib/python3.10/dist-packages/lab/shape.py in f_wrapped(*args, **kw_args)
    183         @wraps(f)
    184         def f_wrapped(*args, **kw_args):
--> 185             return f(*(unwrap_dimension(arg) for arg in args), **kw_args)
    186 
    187         return dispatch(f_wrapped)

/usr/local/lib/python3.10/dist-packages/lab/tensorflow/shaping.py in concat(axis, *elements)
     68 @dispatch
     69 def concat(*elements: Numeric, axis: Int = 0):
---> 70     return tf.concat(elements, axis=axis)
     71 
     72 

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7260 def raise_from_not_ok_status(e, name):
   7261   e.message += (" name: " + name if name is not None else "")
-> 7262   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7263 
   7264 

InvalidArgumentError: {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Dimension 2 in both shapes must be equal: shape[0] = [1,2,0] vs. shape[1] = [1,2,63] [Op:ConcatV2] name: concat

wesselb commented 1 year ago

@tom-andersson thanks for opening this issue! This appears to be an off-by-one error on my part. I'm on it and will push a fix soon :)
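For readers following along, here is a minimal NumPy sketch of the pad-and-mask idea that the traceback suggests `merge_contexts` relies on: each context set is padded to a common number of observations before being concatenated along axis 0, with a mask recording which entries are real. This is not the library's implementation. The names `pad_and_mask` and `merge_contexts_sketch` are hypothetical, and NumPy stands in for the `lab` backend. It only illustrates that a zero-length context set can be handled once everything is padded to the same length.

```python
import numpy as np


def pad_and_mask(a, n_max):
    """Zero-pad the last axis of `a` to length `n_max` (hypothetical helper).

    Returns the padded array and a mask that is 1 for real observations
    and 0 for padding.
    """
    n = a.shape[-1]
    padded = np.zeros(a.shape[:-1] + (n_max,), dtype=a.dtype)
    padded[..., :n] = a  # A no-op when n == 0.
    mask = np.zeros_like(padded)
    mask[..., :n] = 1
    return padded, mask


def merge_contexts_sketch(*contexts):
    """Merge (x, y) context sets of shape (batch, dim, n) along the batch axis."""
    n_max = max(x.shape[-1] for x, _ in contexts)
    xs, ys, masks = [], [], []
    for x, y in contexts:
        x_padded, _ = pad_and_mask(x, n_max)
        y_padded, mask = pad_and_mask(y, n_max)
        xs.append(x_padded)
        ys.append(y_padded)
        masks.append(mask)
    return (
        np.concatenate(xs, axis=0),
        np.concatenate(ys, axis=0),
        np.concatenate(masks, axis=0),
    )


rng = np.random.default_rng(0)
empty = (rng.standard_normal((1, 2, 0)), rng.standard_normal((1, 1, 0)))
full = (rng.standard_normal((1, 2, 63)), rng.standard_normal((1, 1, 63)))

# Naive concatenation along the batch axis fails because the observation
# dimensions differ (0 vs. 63), analogous to the ConcatV2 error above.
try:
    np.concatenate([empty[0], full[0]], axis=0)
except ValueError as e:
    print("naive concat failed:", e)

# After padding to a common length, concatenation succeeds.
x, y, mask = merge_contexts_sketch(empty, full)
print(x.shape, y.shape, mask.shape)  # (2, 2, 63) (2, 1, 63) (2, 1, 63)
```

In this sketch the mask row for the empty context set is all zeros, so downstream code can ignore its padded entries entirely.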

wesselb commented 1 year ago

This should be fixed in the latest release, v0.2.2.

tom-andersson commented 1 year ago

Fantastic, thanks for the rapid fix @wesselb! Now works on my side over at DeepSensor too.