

Advanced, automatic shape checking #50

Open nelson-lojo opened 7 months ago


PyTorch has recently started rolling out prototype shape checking (torch.export with dynamic_shapes). If we implement this in ContextModel and FunctionClass, we would get closer to our goal of "eager tests" dispersed throughout the codebase.

I'm not a huge fan of the approach PyTorch currently offers, but if we had to integrate it, we could do it as follows:

In FunctionClass:

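```python
import torch
from torch.export import export


class _XYPassThrough(torch.nn.Module):
    # torch.nn.Identity.forward accepts a single tensor, so tracing
    # (x_batch, y_batch) needs a two-input pass-through module.
    def forward(self, x_batch, y_batch):
        return x_batch, y_batch


class FunctionClass:
    def __init__(self, ...):
        ...
        # Every dimension here is fixed (min == max), so static shapes express
        # the contract directly; Dim is meant for sizes that vary. Without
        # dynamic_shapes, export treats all dims as static, and the module
        # returned by .module() checks each call's input shapes against these
        # example inputs.
        example_x = torch.empty(self.batch_size, self.sequence_length, self.x_dim)
        example_y = torch.empty(self.batch_size, self.sequence_length, self.y_dim)
        # Export once here instead of re-tracing on every batch.
        self._shape_checker = export(_XYPassThrough(), (example_x, example_y)).module()

    def type_check(self, x, y):
        # Raises if x or y does not have the expected shape.
        self._shape_checker(x, y)

    def produce_x_y_batch(self):
        ... # current body of __next__

    def __next__(self):
        x, y = self.produce_x_y_batch()
        self.type_check(x, y)
        return x, y
```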
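For comparison with the torch.export route, here is a minimal, torch-free sketch of the same fixed-shape check. It is an assumption on my part, not part of the proposal: the helper name check_named_shapes and its three-dict interface are hypothetical. Each named dimension (b_size, seq_len, x_dim, y_dim) maps to exactly one expected size, so x and y are automatically held to the same batch size and sequence length. It only assumes shapes are tuples of ints, which also covers torch.Size because it subclasses tuple.

```python
from typing import Dict, Sequence


# Hypothetical helper, not a PyTorch API: checks actual shapes against
# named dimensions, where each dimension name maps to one expected size.
def check_named_shapes(
    shapes: Dict[str, Sequence[int]],
    spec: Dict[str, Sequence[str]],
    sizes: Dict[str, int],
) -> None:
    """Raise ValueError if any shape in `shapes` violates `spec`.

    shapes: tensor name -> actual shape (a tuple or torch.Size)
    spec:   tensor name -> dimension names, one per axis
    sizes:  dimension name -> expected size
    """
    for tensor_name, dim_names in spec.items():
        shape = tuple(shapes[tensor_name])
        # Rank check first so the per-axis loop below lines up.
        if len(shape) != len(dim_names):
            raise ValueError(
                f"{tensor_name}: expected {len(dim_names)} dims "
                f"{tuple(dim_names)}, got shape {shape}"
            )
        # A name shared by several tensors (e.g. "b_size") is looked up in
        # the same `sizes` entry, so those axes must all agree.
        for axis, (dim_name, actual) in enumerate(zip(dim_names, shape)):
            expected = sizes[dim_name]
            if actual != expected:
                raise ValueError(
                    f"{tensor_name}.shape[{axis}] ({dim_name}): "
                    f"expected {expected}, got {actual}"
                )


if __name__ == "__main__":
    spec = {
        "x_batch": ("b_size", "seq_len", "x_dim"),
        "y_batch": ("b_size", "seq_len", "y_dim"),
    }
    sizes = {"b_size": 4, "seq_len": 10, "x_dim": 3, "y_dim": 1}
    # Matching shapes pass silently.
    check_named_shapes({"x_batch": (4, 10, 3), "y_batch": (4, 10, 1)}, spec, sizes)
    try:
        check_named_shapes({"x_batch": (4, 10, 3), "y_batch": (4, 9, 1)}, spec, sizes)
    except ValueError as err:
        print(err)  # y_batch.shape[1] (seq_len): expected 10, got 9
```

Inside FunctionClass.__next__ this would be called with {"x_batch": x.shape, "y_batch": y.shape} and sizes built from self.batch_size, self.sequence_length, self.x_dim and self.y_dim. The trade-off is that it checks shapes only (not dtypes or devices), but it needs no tracing or export step.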