pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License
475 stars 23 forks source link

Creating block mask with mask mod and _compile=True #40

Closed johng149 closed 1 month ago

johng149 commented 2 months ago

I am following the blog post on using Flex Attention with Document Masking, using torch version 2.6.0.dev20240915+cu118, and I have written the following:

def test_create_block_compiled():
    """Check create_block_mask(..., _compile=True) with a document-masking mask_mod.

    Builds a per-token document id vector (the first 6 tokens belong to
    document 0, the remaining 4 to document 1) and a mask_mod that only lets
    a query attend to keys from the same document. Building the block mask
    with ``_compile=True`` used to fail while tracing ``doc_ids[q_idx]`` under
    vmap ("We don't support vmap over calling .item() on a Tensor"). The
    compiled mask must match the eager (``_compile=False``) mask.
    """
    device = "cuda"
    seq_len = 10

    doc_1_len = 6
    doc_2_len = seq_len - doc_1_len

    # Document id for each token position: 0 for the first doc, 1 for the second.
    doc_ids = torch.cat(
        [
            torch.zeros(doc_1_len, dtype=torch.long, device=device),
            torch.ones(doc_2_len, dtype=torch.long, device=device),
        ],
        dim=0,
    )

    def mask_fn(b, h, q_idx, kv_idx):
        # Attention is allowed only between tokens of the same document.
        return doc_ids[q_idx] == doc_ids[kv_idx]

    q_len = seq_len
    kv_len = seq_len

    def build(compile_mask):
        # Identical arguments for both paths; only the _compile flag differs.
        return create_block_mask(
            mask_mod=mask_fn,
            B=None,
            H=None,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            BLOCK_SIZE=(q_len, kv_len),
            device=device,
            _compile=compile_mask,
        )

    # Build the compiled mask first so the original failure still reproduces here.
    mask = build(True)
    reference = build(False)
    assert torch.equal(mask.to_dense(), reference.to_dense())

However, setting _compile=True causes an error when creating the block mask, while setting _compile=False works without problems. The trace is:

.venv/lib64/python3.12/site-packages/torch/nn/attention/flex_attention.py:851: in create_block_mask
    partial_block_mask, full_block_mask = inner_func(
.venv/lib64/python3.12/site-packages/torch/_dynamo/eval_frame.py:465: in _fn
    return fn(*args, **kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:1292: in __call__
    return self._torchdynamo_orig_callable(
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:530: in __call__
    return _compile(
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:933: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:675: in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
.venv/lib64/python3.12/site-packages/torch/_utils_internal.py:85: in wrapper_function
    return function(*args, **kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:708: in _compile_inner
    out_code = transform_code_object(code, transform)
.venv/lib64/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1322: in transform_code_object
    transformations(instructions, code_options)
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:220: in _fn
    return fn(*args, **kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/convert_frame.py:643: in transform
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2776: in run
    super().run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2275: in CALL
    self._call(inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2269: in _call
    self.call_function(fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2275: in CALL
    self._call(inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2269: in _call
    self.call_function(fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py:1489: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py:1489: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py:1489: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py:1489: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1676: in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:826: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:326: in call_function
    return super().call_function(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/functions.py:111: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:832: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2991: in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3119: in inline_call_
    tracer.run()
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:979: in run
    while self.step():
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:891: in step
    self.dispatch_table[inst.opcode](self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:569: in wrapper
    return inner_fn(self, inst)
.venv/lib64/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:288: in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/builtin.py:967: in call_function
    return handler(tx, args, kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/builtin.py:943: in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/builder.py:2045: in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
.venv/lib64/python3.12/site-packages/torch/_dynamo/variables/builder.py:2132: in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
.venv/lib64/python3.12/site-packages/torch/_dynamo/utils.py:2101: in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
.venv/lib64/python3.12/site-packages/torch/_dynamo/utils.py:2036: in get_fake_value
    ret_val = wrap_fake_exception(
.venv/lib64/python3.12/site-packages/torch/_dynamo/utils.py:1595: in wrap_fake_exception
    return fn()
.venv/lib64/python3.12/site-packages/torch/_dynamo/utils.py:2037: in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
.venv/lib64/python3.12/site-packages/torch/_dynamo/utils.py:2169: in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tracer = <torch._dynamo.output_graph.OutputGraph object at 0x7f59bf13fd40>
node = getitem
args = (FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64), BatchedTensor(lvl=3, bdim=0, value=
    FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64)
))
kwargs = {}, nnmodule = None

    def run_node(tracer, node, args, kwargs, nnmodule):
        """
        Runs a given node, with the given args and kwargs.

        Behavior is dictated by a node's op.

        run_node is useful for extracting real values out of nodes.
        See get_real_value for more info on common usage.

        Note: The tracer arg is only used for 'get_attr' ops
        Note: The nnmodule arg is only used for 'call_module' ops

        Nodes that are not call_function, call_method, call_module, or get_attr will
        raise an AssertionError.
        """
        op = node.op

        with set_current_node(node):

            def make_error_message(e):
                return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e)

            try:
                if op == "call_function":
>                   return node.target(*args, **kwargs)
E                   torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64), BatchedTensor(lvl=3, bdim=0, value=
E                       FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64)
E                   )), **{}):
E                   vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

I am not sure why this is happening. Should I construct the BlockMask myself with a custom constructor, or simply avoid _compile=True?

Chillee commented 1 month ago

This appears to be a regression that slipped through. We're currently looking at it.

cc: @anijain2305

Chillee commented 1 month ago

This should be fixed on nightly though. Let us know if you still hit it!

johng149 commented 1 month ago

Thanks, I updated to the latest nightly and it works.