pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[fx] symbolic trace "is None" and "is not None" checks #45685

Open jerryzh168 opened 3 years ago

jerryzh168 commented 3 years ago

There are many is None and is not None checks in https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py, for example: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py#L193

if input_ids is not None:
    input_shape = input_ids.size()
else:
    input_shape = inputs_embeds.size()[:-1]

What is the best way to work around these?

cc @ezyang @SherlockNoMad

suo commented 3 years ago

This is data-dependent control flow and thus not supported by symbolic tracing. The recommendation is to change the model code to not have data-dependent control flow, or to treat this module as a leaf module in the tracer so we don't trace through it.
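
The leaf-module option mentioned above could look roughly like the sketch below. It relies on the Tracer.is_leaf_module hook in torch.fx; the Embeddings and Model classes are made-up toy modules for illustration, not the Hugging Face code.

```python
import torch
import torch.fx


class Embeddings(torch.nn.Module):
    """Toy module (hypothetical) that fills a default with an `is None` check."""

    def __init__(self):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(10, 4)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        return inputs_embeds


class Model(torch.nn.Module):
    """Toy parent module (hypothetical) that calls Embeddings."""

    def __init__(self):
        super().__init__()
        self.embeddings = Embeddings()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, input_ids):
        return self.linear(self.embeddings(input_ids))


class LeafTracer(torch.fx.Tracer):
    """Tracer that records Embeddings as one opaque call instead of tracing into it."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, Embeddings):
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = Model()
graph = LeafTracer().trace(model)
gm = torch.fx.GraphModule(model, graph)

# The `is None` check still runs eagerly inside Embeddings at call time.
ids = torch.tensor([[1, 2, 3]])
out = gm(ids)
```

The trade-off is that nothing inside the leaf module appears in the graph, so graph transformations cannot see or rewrite it.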

jerryzh168 commented 3 years ago

I'm actually asking whether we can add some special support for this case so that the rewrite is easier, since this pattern occurs many times; e.g. 'is not None' occurs 52 times in https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py. For example, we added torch.Assert to replace assert.

Maybe https://github.com/pytorch/pytorch/issues/45682 is related?

suo commented 3 years ago

I don't think there is any special support we can add for this case. assert is relatively easy to add support for, because the control flow is always the same; it essentially translates to:

if not assert_condition:
    raise AssertionError("blah")

However, if x is None and if x is not None are indistinguishable from any other control flow; anything can happen in the if and else branches. In cases where x is derived from the inputs, the function behavior is truly dependent on what the inputs are, and thus symbolic tracing cannot support it generically.
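
A toy illustration of why this check is hard for a tracer to see: Python's is operator cannot be overloaded, so an identity test against a stand-in object is decided on the spot, and only one branch is ever recorded. ToyProxy below is a made-up class for illustration and is not how torch.fx implements its proxies.

```python
class ToyProxy:
    """Hypothetical stand-in for a symbolic value seen during tracing."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Overloadable operators can be intercepted and recorded symbolically.
        return ToyProxy(f"eq({self.name}, {other!r})")

    def __bool__(self):
        # Branching on a symbolic value has no single answer, so refuse.
        raise TypeError(f"symbolic value {self.name!r} used in control flow")


def with_identity_check(optional_value=None):
    if optional_value is None:  # `is` never calls into ToyProxy
        return "filled default"
    return "used caller value"


def with_equality_check(optional_value=None):
    if optional_value == None:  # noqa: E711 (goes through __eq__, then __bool__)
        return "filled default"
    return "used caller value"


# The identity check silently picks one branch: a proxy object is never None.
traced_branch = with_identity_check(ToyProxy("optional_value"))

# The equality check is intercepted, and branching on its result is rejected.
try:
    with_equality_check(ToyProxy("optional_value"))
    equality_error = None
except TypeError as exc:
    equality_error = exc
```

In this toy setup the identity check is invisible to the stand-in object, so whichever branch runs during tracing is the only one that could end up in a recorded program.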

jerryzh168 commented 3 years ago

I think most of them are used to initialize some values. Can we just support that specific use case instead of supporting all possible uses of 'is not None'/'is None'? x is typically just one of the optional inputs, not derived from the inputs, e.g.

def forward(x, optional_value=None):
    if optional_value is None:
        optional_value = torch.zeros(...)
    ...

suo commented 3 years ago

Here is a relevant snippet of control flow from huggingface bert:

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

Conceivably, the pattern where we are filling default arguments could be contained in a single statement like torch.Assert—something like:

torch.value_or(inputs_embeds, self.word_embeddings(input_ids))

But it would probably be quite limited, and we couldn't handle the first case, so the function would need to be rewritten anyway.
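
One limitation can be sketched in plain Python. torch.value_or is only a proposal in this thread, so the two helpers below are hypothetical. With an ordinary function call, the default expression is evaluated before the helper runs, even when it is not needed; in the BERT snippet that would mean calling self.word_embeddings(input_ids) even when input_ids is None. A variant that takes a callable avoids this.

```python
def value_or(value, default):
    """Hypothetical null-coalescing helper: return `value` unless it is None."""
    return default if value is None else value


def value_or_else(value, make_default):
    """Hypothetical lazy variant: only call `make_default()` when `value` is None."""
    return make_default() if value is None else value


calls = []


def expensive_default():
    # Records each call so we can see when the default is actually computed.
    calls.append("called")
    return [0, 0, 0]


# Eager form: the default is computed even though a value was supplied.
eager = value_or([1, 2, 3], expensive_default())

# Lazy form: the default is never computed because a value was supplied.
lazy = value_or_else([1, 2, 3], expensive_default)
```

After both calls, expensive_default has run once, and only because of the eager form.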

jerryzh168 commented 3 years ago

Yeah, that's similar to what I have in mind. For

if input_ids is not None:
    input_shape = input_ids.size()
else:
    input_shape = inputs_embeds.size()[:-1]

can we do

input_shape = torch.either_or(input_ids.isNone().not(), input_ids.size(), inputs_embeds.size()[:-1])

suo commented 3 years ago

This is conceivable, but starts significantly expanding the scope of FX. In addition, input_ids.isNone().not() is not an API that generically works across values in Python or even PyTorch, so it's not clear how we would express this predicate in the FX generated code.

My inclination is to say we should not support any general control flow constructs for the initial release of FX, and keep it pretty restricted to assert and maybe value_or. @jamesr66a, what do you think?

jamesr66a commented 3 years ago

Sounds fine to me for the initial release.

IMO building up a set of functional primitives (assert, null coalescing, maybe higher order loops or branches) is fine, but where I have doubts is:

IMO we can split the difference here by building out a more sophisticated frontend over time (AST rewrites, bytecode analyses, etc) and aggressively converting mutable constructs into functional forms. That way users are happy and backends don't have to bang their heads against the desk dealing with a mutable program representation. But the key here is that we need to resist the inevitable demand to build a bunch of half-finished crap so some team can hit a fake deadline, and actually design this well.

jerryzh168 commented 3 years ago

  • Are users going to bother rewriting their code with these?

The goal is to make it as easy as possible for users to symbolically trace their models. I feel this is less intrusive than moving this code into leaf modules; is that true?

If the rewrite can be done automatically (e.g. with an AST rewrite), that would be even better, but I think building these things (torch.Assert, torch.value_or, etc.) is a prerequisite for that, right?
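
A minimal sketch of what such an automatic rewrite might look like, using only Python's standard ast module. It handles just the narrowest pattern (if NAME is None: NAME = EXPR, with no else branch) and targets a hypothetical value_or_else helper; neither the helper nor the rewriter exists in PyTorch, and make_default is a placeholder function.

```python
import ast


def value_or_else(value, make_default):
    """Hypothetical helper: return `value`, or `make_default()` if it is None."""
    return make_default() if value is None else value


class NoneDefaultRewriter(ast.NodeTransformer):
    """Rewrite `if NAME is None: NAME = EXPR` to `NAME = value_or_else(NAME, lambda: EXPR)`."""

    def visit_If(self, node):
        self.generic_visit(node)  # handle nested ifs first
        test = node.test
        matches = (
            isinstance(test, ast.Compare)
            and isinstance(test.left, ast.Name)
            and len(test.ops) == 1
            and isinstance(test.ops[0], ast.Is)
            and isinstance(test.comparators[0], ast.Constant)
            and test.comparators[0].value is None
            and not node.orelse
            and len(node.body) == 1
            and isinstance(node.body[0], ast.Assign)
            and len(node.body[0].targets) == 1
            and isinstance(node.body[0].targets[0], ast.Name)
            and node.body[0].targets[0].id == test.left.id
        )
        if not matches:
            return node  # leave every other if statement untouched
        name = test.left.id
        # Build the replacement from a template, then splice in the original EXPR.
        new_assign = ast.parse(f"{name} = value_or_else({name}, lambda: None)").body[0]
        new_assign.value.args[1].body = node.body[0].value
        return ast.copy_location(new_assign, node)


SOURCE = '''
def forward(x, optional_value=None):
    if optional_value is None:
        optional_value = make_default(x)
    return x + optional_value
'''

tree = NoneDefaultRewriter().visit(ast.parse(SOURCE))
ast.fix_missing_locations(tree)
rewritten_source = ast.unparse(tree)  # requires Python 3.9+

# Execute the rewritten function; make_default is a placeholder for illustration.
namespace = {"value_or_else": value_or_else, "make_default": lambda x: 10}
exec(compile(tree, "<rewritten>", "exec"), namespace)
forward = namespace["forward"]
```

The default is wrapped in a lambda so that the rewritten code evaluates EXPR only when the argument is None, matching the original if statement. The first branch in the BERT snippet (if/else on input_ids) does not fit this pattern and would be left unchanged.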

bhack commented 1 month ago

What is the status of this?

I had multiple recompilations with

expected type of 'L['curr_id_emb']' to be a tensor type, ' but found <class 'NoneType'>

Using

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)
            if self.d_att is not None:
                global_K = self.linear_Kd(global_K)
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)

            if self.global_dilation>1 and self.memory_dilation:
                nhw,bs,ck = global_K.shape
                cv = global_V.shape[-1]
                # n = nhw // (size_2d[0] * size_2d[1])
                d = self.global_dilation
                if self.conv_dilation:
                    unfold_K = global_K.permute(1,2,0).reshape(bs,ck,size_2d[0],size_2d[1])
                    unfold_V = global_V.permute(1,2,0).reshape(bs,cv,size_2d[0],size_2d[1])
                    global_K = self.dilation_conv_K(unfold_K).reshape(bs,ck,-1).permute(2,0,1)
                    global_V = self.dilation_conv_V(unfold_V).reshape(bs,cv,-1).permute(2,0,1)
                else:
                    unfold_K = global_K.view(size_2d[0],size_2d[1],bs,ck)
                    unfold_V = global_V.view(size_2d[0],size_2d[1],bs,cv)
                    global_K = unfold_K[::d,::d,:,:].reshape(-1,bs,ck)
                    global_V = unfold_V[::d,::d,:,:].reshape(-1,bs,cv)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory