Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

thunder may treat global (maybe nonlocal) value as constant in computation trace without a check in prologue #1464

Open kshitij12345 opened 4 days ago

kshitij12345 commented 4 days ago
import torch
import thunder
from contextvars import ContextVar

_compile_data = ContextVar("compile_data", default=1)

def fn(x):
    v = _compile_data.get()
    return x + v

jfn = thunder.jit(fn)
o = jfn(torch.ones(3,))
print(o)  # tensor([2., 2., 2.])

_compile_data.set(2)  # update the ContextVar; the cached trace below never re-reads it
o = jfn(torch.ones(3,))
print(o)  # tensor([2., 2., 2.]) (should be tensor([3., 3., 3.]))

print(thunder.last_prologue_traces(jfn)[-1])
print(thunder.last_traces(jfn)[-1])
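
On the second call, jfn hits its cache and replays the stale computation trace, so the output still reflects the value captured at trace time (1) rather than the current value (2). A workaround is to pass the value as an explicit argument: thunder's default cache guards Python number arguments in the prologue (via check_number_type_and_value, the same guard used for the cache_info entries in the trace below), so a changed value should trigger a retrace. A minimal sketch, where fn_explicit is a hypothetical rewrite of fn:

import torch
import thunder
from contextvars import ContextVar

_compile_data = ContextVar("compile_data", default=1)

def fn_explicit(x, v):
    # the ContextVar value is now a traced input, not a captured constant
    return x + v

jfn_explicit = thunder.jit(fn_explicit)
print(jfn_explicit(torch.ones(3), _compile_data.get()))  # tensor([2., 2., 2.])

_compile_data.set(2)
print(jfn_explicit(torch.ones(3), _compile_data.get()))  # tensor([3., 3., 3.])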

Prologue Trace

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 1)
    # prims.check_len(args, 1)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  x: "cpu f32[3]" = args[0]
  check_tensor_metadata(x, (3,), 'cpu', torch.float32, False)
    # prims.check_tensor_shape_and_metadata(x, (3,), 'cpu', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
    # prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
    # prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
    # prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
    # prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
  check_string_value(cache_info_alias_tensor_indices, '')
    # prims.check_string_value(cache_info_alias_tensor_indices, '')
  cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
  check_number_type_and_value(cache_info_is_grad_enabled, True)
    # prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
  return ((x,), ())
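
The prologue validates the tensor argument's metadata and the cache_info entries, but nothing in it re-reads _compile_data, so any later value of the ContextVar silently passes the cache check. A fix would presumably have the prologue emit a guard for captured globals as well; a hypothetical sketch of the missing check, reusing check_number_type_and_value from the trace above:

  v = _compile_data.get()
  check_number_type_and_value(v, 1)
    # prims.check_number_type_and_value(v, 1)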

Computation Trace

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[3]"
  t0 = torch.add(x, 1, alpha=1)  # t0: "cpu f32[3]"
    # t0 = ltorch.add(x, 1, alpha=1)  # t0: "cpu f32[3]"
      # _ = prims.convert_element_type(1, float)
      # t0 = prims.add(x, 1.0)  # t0: "cpu f32[3]"
  return t0
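
The value read from the ContextVar at trace time is baked in as the literal 1 in torch.add(x, 1); the computation trace keeps no reference to _compile_data at all, which is why the second call cannot observe the update. Until the prologue guards such captured values, re-jitting after the ContextVar changes is a heavyweight workaround, since a fresh thunder.jit call starts with an empty cache and retraces. A sketch, assuming the repro above has already run:

jfn2 = thunder.jit(fn)      # new cache, so the first call retraces fn
print(jfn2(torch.ones(3)))  # tensor([3., 3., 3.]) now that _compile_data is 2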
kshitij12345 commented 4 days ago

Relevant Conversation: https://github.com/Lightning-AI/lightning-thunder/pull/1458#discussion_r1852358009