Open jerryzh168 opened 3 years ago
This is data-dependent control flow and thus not supported by symbolic tracing. The recommendation is to change the model code to not have data-dependent control flow, or to treat this module as a leaf module in the tracer so we don't trace through it.
I'm actually asking whether we can add some special support for this case so that the rewrite is easier, since there are many occurrences of this pattern, e.g. 'is not None' occurs 52 times in https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py
For example, we added torch.Assert to replace assert.
Maybe https://github.com/pytorch/pytorch/issues/45682 is related?
I don't think there is any special support we can add for this case. assert is relatively easy to add support for, because the control flow is always the same; it essentially translates to:
if not assert_condition:
raise AssertionError("blah")
However, if x is None and if x is not None are indistinguishable from any other control flow; anything can happen in the if and else branches. In cases where x is derived from the inputs, the function behavior is truly dependent on what the inputs are, and thus symbolic tracing cannot support it generically.
I think most of them are used to initialize some values; can we just support that specific use case instead of supporting all possible uses of 'is not None'/'is None'? x is typically just one of the optional inputs, not derived from the input.
e.g.
def forward(x, optional_value=None):
    if optional_value is None:
        optional_value = torch.zeros(...)
    ...
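To make the distinction concrete, here is a minimal pure-Python sketch of a Proxy-style tracer (illustrative names only, not the real torch.fx classes). It shows how an is None check on a Proxy silently specializes the recorded graph to whichever branch the Proxy happens to take:

```python
# Minimal stand-in for an FX-style tracer (illustrative only, not the
# real torch.fx classes), showing why `is None` silently specializes
# the trace.
class Proxy:
    """Stands in for a real value during tracing; records ops instead
    of computing them."""
    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def __add__(self, other):
        self.graph.append(("add", self.name, getattr(other, "name", other)))
        return Proxy(f"v{len(self.graph)}", self.graph)


def forward(x, optional_value=None):
    if optional_value is None:  # evaluated on the Proxy itself, not on data
        optional_value = 0
    return x + optional_value


# Trace once with the optional input supplied: a Proxy is never None,
# so only the "provided" branch is recorded.
g1 = []
forward(Proxy("x", g1), Proxy("optional_value", g1))

# Trace once with it omitted: now the other branch is baked in.
g2 = []
forward(Proxy("x", g2))
```

Either trace is valid only for call sites that match the tracing call; the None check itself never appears in the graph.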
Here is a relevant snippet of control flow from huggingface bert:
if input_ids is not None:
    input_shape = input_ids.size()
else:
    input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]

if position_ids is None:
    position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
    inputs_embeds = self.word_embeddings(input_ids)
Conceivably, the pattern where we are filling default arguments could be contained in a single statement like torch.Assert, something like:

torch.value_or(inputs_embeds, self.word_embeddings(input_ids))
But it would probably be quite limited, and we couldn't handle the first case, so the function would need to be rewritten anyway.
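As a sketch, the semantics of such a helper could look like this (value_or here is hypothetical, not an existing torch API); the docstring spells out the eager-evaluation limitation alluded to above:

```python
def value_or(value, default):
    """Hypothetical `torch.value_or` semantics: return `value` unless it
    is None, in which case return `default`.

    As a single traceable call, `default` is computed eagerly even when
    `value` is present; a call like
        value_or(inputs_embeds, self.word_embeddings(input_ids))
    would therefore still fail whenever input_ids is the argument that
    is None.
    """
    return value if value is not None else default


filled = value_or(None, 3)  # default wins
kept = value_or(7, 3)       # provided value wins
```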
Yeah, that's similar to what I have in mind. For
if input_ids is not None:
    input_shape = input_ids.size()
else:
    input_shape = inputs_embeds.size()[:-1]
can we do
input_shape = torch.either_or(input_ids.isNone().not(), input_ids.size(), inputs_embeds.size()[:-1])
This is conceivable, but it starts significantly expanding the scope of FX. In addition, input_ids.isNone().not() is not an API that generically works across values in Python or even PyTorch, so it's not clear how we would express this predicate in the FX-generated code.
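For comparison, the eager semantics of the proposed select would be roughly as follows (either_or is hypothetical, not a real torch API), and the comments note why it is unsafe for exactly the BERT snippet being discussed:

```python
def either_or(pred, if_true, if_false):
    # Hypothetical functional select that a tracer could record as one
    # op with three inputs. Unlike a Python if/else, both candidate
    # values are computed before the call -- which is what makes it
    # traceable, but also unsafe when one branch is only valid for some
    # inputs (e.g. input_ids.size() would raise when input_ids is None).
    return if_true if pred else if_false


chosen_true = either_or(True, 1, 2)
chosen_false = either_or(False, 1, 2)
```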
My inclination is to say we should not support any general control flow constructs for the initial release of FX, and keep it pretty restricted to assert and maybe value_or. @jamesr66a, what do you think?
Sounds fine to me for the initial release.
IMO building up a set of functional primitives (assert, null coalescing, maybe higher-order loops or branches) is fine, but where I have doubts is:
- Are users going to bother rewriting their code with these?

IMO we can split the difference here by building out a more sophisticated frontend over time (AST rewrites, bytecode analyses, etc.) and aggressively converting mutable constructs into functional forms. That way users are happy and backends don't have to bang their heads against the desk dealing with a mutable program representation. But the key here is that we need to resist the inevitable demand to build a bunch of half-finished crap so some team can hit a fake deadline, and actually design this well.
The goal is to make it as easy as possible for users to symbolically trace their model. I feel this is less intrusive than moving this code into leaf modules, is that true?
If the rewrite can be done automatically (e.g. with an AST rewrite), that would be even better, but I think building these things (torch.Assert, torch.value_or, etc.) is a prerequisite for that, right?
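A rough sketch of what such an automatic rewrite could look like, using Python's ast module. Here value_or is a hypothetical helper, and only the simplest `if x is None: x = <expr>` shape is matched:

```python
import ast


class FillDefaultRewriter(ast.NodeTransformer):
    """Rewrite `if x is None: x = <expr>` into `x = value_or(x, <expr>)`."""

    def visit_If(self, node):
        test = node.test
        if (isinstance(test, ast.Compare)
                and isinstance(test.left, ast.Name)
                and len(test.ops) == 1
                and isinstance(test.ops[0], ast.Is)
                and isinstance(test.comparators[0], ast.Constant)
                and test.comparators[0].value is None
                and not node.orelse
                and len(node.body) == 1
                and isinstance(node.body[0], ast.Assign)
                and isinstance(node.body[0].targets[0], ast.Name)
                and node.body[0].targets[0].id == test.left.id):
            name = test.left.id
            # Build `name = value_or(name, <expr>)` in place of the If.
            call = ast.Call(
                func=ast.Name(id="value_or", ctx=ast.Load()),
                args=[ast.Name(id=name, ctx=ast.Load()), node.body[0].value],
                keywords=[],
            )
            new = ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())], value=call)
            return ast.copy_location(new, node)
        return node


def value_or(value, default):
    # Hypothetical helper; a real tracer would record this as one op.
    return value if value is not None else default


src = """
def forward(x, optional_value=None):
    if optional_value is None:
        optional_value = 0
    return x + optional_value
"""
tree = ast.fix_missing_locations(FillDefaultRewriter().visit(ast.parse(src)))
namespace = {"value_or": value_or}
exec(compile(tree, "<rewritten>", "exec"), namespace)
forward = namespace["forward"]
```

The rewritten forward behaves like the original but contains no if on the optional input, so a tracer only needs to understand the value_or call.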
What is the status of this?
I had multiple recompilations with:

expected type of 'L['curr_id_emb']' to be a tensor type, ' but found <class 'NoneType'>
Using
if curr_id_emb is not None:
    global_K, global_V = self.fuse_key_value_id(
        curr_K, curr_V, curr_id_emb)
    if self.d_att is not None:
        global_K = self.linear_Kd(global_K)
    local_K = seq_to_2d(global_K, size_2d)
    local_V = seq_to_2d(global_V, size_2d)
    if self.global_dilation > 1 and self.memory_dilation:
        nhw, bs, ck = global_K.shape
        cv = global_V.shape[-1]
        # n = nhw // (size_2d[0] * size_2d[1])
        d = self.global_dilation
        if self.conv_dilation:
            unfold_K = global_K.permute(1, 2, 0).reshape(bs, ck, size_2d[0], size_2d[1])
            unfold_V = global_V.permute(1, 2, 0).reshape(bs, cv, size_2d[0], size_2d[1])
            global_K = self.dilation_conv_K(unfold_K).reshape(bs, ck, -1).permute(2, 0, 1)
            global_V = self.dilation_conv_V(unfold_V).reshape(bs, cv, -1).permute(2, 0, 1)
        else:
            unfold_K = global_K.view(size_2d[0], size_2d[1], bs, ck)
            unfold_V = global_V.view(size_2d[0], size_2d[1], bs, cv)
            global_K = unfold_K[::d, ::d, :, :].reshape(-1, bs, ck)
            global_V = unfold_V[::d, ::d, :, :].reshape(-1, bs, cv)
else:
    global_K, global_V = long_term_memory
    local_K, local_V = short_term_memory
There are many is None / is not None checks in https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py, for example: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py#L193
What is the best way to work around these?
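For what it's worth, one option that exists in torch.fx today for the optional-input pattern is to specialize the trace via symbolic_trace's concrete_args parameter, which resolves the is None branch at trace time instead of tracing through it. A sketch (whether it helps with the Dynamo recompilations above is a separate question):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x, optional_value=None):
        # Data-dependent control flow on an optional input.
        if optional_value is None:
            optional_value = torch.zeros_like(x)
        return x + optional_value


# Pin optional_value to None for this trace: the `is None` check runs on
# a real None during tracing, so only that branch is recorded.
traced = torch.fx.symbolic_trace(M(), concrete_args={"optional_value": None})
out = traced(torch.ones(2), None)  # specialized to the None case
```

The resulting graph is only valid for calls where optional_value is None; a caller that does pass a tensor would need a separate trace.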
cc @ezyang @SherlockNoMad