hail-is / hail

Cloud-native genomic dataframes and batch computing
https://hail.is

[query] bug in linear_regression_rows #14594

Open. patrick-schultz opened this issue 4 months ago

patrick-schultz commented 4 months ago

What happened?

Reported by https://hail.zulipchat.com/#narrow/stream/123010-Hail-Query-0.2E2-support/topic/Bugs.20of.20weights.20option.20in.20linear_regression_row/near/448000375

Simplified reproducer:

import hail as hl

mt = hl.read_matrix_table('data/1kg.mt')
mt = mt.annotate_rows(weights=1)   # constant per-variant weight
mt = mt.annotate_cols(y=1)         # constant phenotype
gwas_weights = hl.linear_regression_rows(
    y=mt.y,
    x=mt.GT.n_alt_alleles(),
    covariates=[1.0],
    weights=mt.weights,
)

This hits a KeyError: 'va' in the CSE renderer. It seems to happen only when a weights argument is given. Why is this not hit by tests like test_weighted_linear_regression?
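For orientation, the failing lookup (hail/python/hail/ir/renderer.py:309 in the traceback below) is a dict lookup of each free variable of an IR node in the renderer's binding context. The following is only an illustrative sketch of that failure mode, not Hail code: it assumes the offending IR node reaches CSEAnalysisPass with 'va' (the row struct) as a free variable that no enclosing binder has recorded.

# Illustrative sketch of CSEAnalysisPass.StackFrame.bind_depth (renderer.py:309).
# The names below are simplified stand-ins, not the real classes.
free_vars = {'va'}          # free (eval) variables of the offending IR node
context = ({}, {}, {})      # (eval, agg, scan) binding-depth dicts; 'va' was never bound
min_binding_depth = 0

bind_depth = min_binding_depth
if len(free_vars) > 0:
    # The real code performs exactly this lookup; a free variable missing
    # from context[0] surfaces as KeyError: 'va'.
    bind_depth = max(bind_depth, *(context[0][var] for var in free_vars))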

Version

0.2.131-37a5ba226bae

Relevant log output

----> 1 gwas_weights = hl._linear_regression_rows_nd(y=mt.y,
      2                                  x=mt.GT.n_alt_alleles(),
      3                                  covariates=[1.0],
      4                                  weights=mt.weights)

File <decorator-gen-1734>:2, in _linear_regression_rows_nd(y, x, covariates, block_size, weights, pass_through)

File ~/hail/hail/python/hail/typecheck/check.py:585, in _make_dec.<locals>.wrapper(__original_func, *args, **kwargs)
    582 @decorator
    583 def wrapper(__original_func: Callable[..., T], *args, **kwargs) -> T:
    584     args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585     return __original_func(*args_, **kwargs_)

File ~/hail/hail/python/hail/methods/statgen.py:717, in _linear_regression_rows_nd(y, x, covariates, block_size, weights, pass_through)
    714 res = res.select_globals()
    716 temp_file_name = hl.utils.new_temp_file("_linear_regression_rows_nd", "result")
--> 717 res = res.checkpoint(temp_file_name)
    719 return res

File <decorator-gen-1234>:2, in checkpoint(self, output, overwrite, stage_locally, _codec_spec, _read_if_exists, _intervals, _filter_intervals)

File ~/hail/hail/python/hail/typecheck/check.py:585, in _make_dec.<locals>.wrapper(__original_func, *args, **kwargs)
    582 @decorator
    583 def wrapper(__original_func: Callable[..., T], *args, **kwargs) -> T:
    584     args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585     return __original_func(*args_, **kwargs_)

File ~/hail/hail/python/hail/table.py:1963, in Table.checkpoint(self, output, overwrite, stage_locally, _codec_spec, _read_if_exists, _intervals, _filter_intervals)
   1960 hl.current_backend().validate_file(output)
   1962 if not _read_if_exists or not hl.hadoop_exists(f'{output}/_SUCCESS'):
-> 1963     self.write(output=output, overwrite=overwrite, stage_locally=stage_locally, _codec_spec=_codec_spec)
   1964     _assert_type = self._type
   1965     _load_refs = False

File <decorator-gen-1236>:2, in write(self, output, overwrite, stage_locally, _codec_spec)

File ~/hail/hail/python/hail/typecheck/check.py:585, in _make_dec.<locals>.wrapper(__original_func, *args, **kwargs)
    582 @decorator
    583 def wrapper(__original_func: Callable[..., T], *args, **kwargs) -> T:
    584     args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585     return __original_func(*args_, **kwargs_)

File ~/hail/hail/python/hail/table.py:2005, in Table.write(self, output, overwrite, stage_locally, _codec_spec)
   1979 """Write to disk.
   1980
   1981 Examples
   (...)
   2000     If ``True``, overwrite an existing file at the destination.
   2001 """
   2003 hl.current_backend().validate_file(output)
-> 2005 Env.backend().execute(
   2006     ir.TableWrite(self._tir, ir.TableNativeWriter(output, overwrite, stage_locally, _codec_spec))
   2007 )

File ~/hail/hail/python/hail/backend/spark_backend.py:227, in SparkBackend.execute(self, ir, timed)
    224     except Exception as fatal:
    225         raise err from fatal
--> 227 raise err

File ~/hail/hail/python/hail/backend/spark_backend.py:219, in SparkBackend.execute(self, ir, timed)
    217 def execute(self, ir: BaseIR, timed: bool = False) -> Any:
    218     try:
--> 219         return super().execute(ir, timed)
    220     except Exception as err:
    221         if self._copy_log_on_error:

File ~/hail/hail/python/hail/backend/backend.py:176, in Backend.execute(self, ir, timed)
    175 def execute(self, ir: BaseIR, timed: bool = False) -> Any:
--> 176     payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed)
    177     try:
    178         result, timings = self._rpc(ActionTag.EXECUTE, payload)

File ~/hail/hail/python/hail/backend/backend.py:193, in Backend._render_ir(self, ir)
    191 def _render_ir(self, ir):
    192     r = CSERenderer()
--> 193     return r(finalize_randomness(ir))

File ~/hail/hail/python/hail/ir/renderer.py:135, in CSERenderer.__call__(self, root)
    134 def __call__(self, root: 'ir.BaseIR') -> str:
--> 135     binding_sites = CSEAnalysisPass(self)(root)
    136     return CSEPrintPass(self)(root, binding_sites)

File ~/hail/hail/python/hail/ir/renderer.py:199, in CSEAnalysisPass.__call__(self, root)
    196 child_frame = frame.make_child_frame(len(stack))
    198 if isinstance(child, ir.IR):
--> 199     bind_depth = child_frame.bind_depth()
    200     lets = None
    201     if bind_depth < len(stack):

File ~/hail/hail/python/hail/ir/renderer.py:309, in CSEAnalysisPass.StackFrame.bind_depth(self)
    307 bind_depth = self.min_binding_depth
    308 if len(self.node.free_vars) > 0:
--> 309     bind_depth = max(bind_depth, *(self.context[0][var] for var in self.node.free_vars))
    310 if len(self.node.free_agg_vars) > 0:
    311     bind_depth = max(bind_depth, *(self.context[1][var] for var in self.node.free_agg_vars))

File ~/hail/hail/python/hail/ir/renderer.py:309, in <genexpr>(.0)
    307 bind_depth = self.min_binding_depth
    308 if len(self.node.free_vars) > 0:
--> 309     bind_depth = max(bind_depth, *(self.context[0][var] for var in self.node.free_vars))
    310 if len(self.node.free_agg_vars) > 0:
    311     bind_depth = max(bind_depth, *(self.context[1][var] for var in self.node.free_agg_vars))

KeyError: 'va'
ehigham commented 4 months ago

next steps: