hidet-org / hidet

An open-source efficient deep learning framework/compiler, written in Python.
https://hidet.org
Apache License 2.0

[Fixbug] Fix dynamic memcpy bug #427

Closed KTong821 closed 5 months ago

KTong821 commented 5 months ago

Minimal failure case:

resize_inputs: Tensor = symbol([1, 3, "h", "w"], dtype="int32", device="cpu")
resize_outputs = self.resize(resize_inputs.to(self.dtype, self.device))  # (float32, cuda)
resize_graph: FlowGraph = trace_from(resize_outputs, resize_inputs)

resize_graph.build()

This compiles to the following launch function, in which the symbols h and w are undefined:

DLL void hidet_launch_0(float * __restrict__ x, float * __restrict__ y) {
  cudaMemcpyAsync(y, x, (4 * ((3 * h) * w)), cudaMemcpyHostToDevice, (cudaStream_t)get_cuda_stream());
}
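
For reference, the size argument in the call above is the byte count of a float32 tensor of shape [1, 3, h, w]. A minimal Python sketch of that arithmetic follows; nbytes is a hypothetical helper written for illustration, not a hidet function, and the concrete values of h and w are made up.

```python
from math import prod


def nbytes(shape, itemsize):
    """Byte count of a dense tensor: element size times number of elements."""
    return itemsize * prod(shape)


# float32 occupies 4 bytes; h and w are only known at run time.
h, w = 224, 224
assert nbytes([1, 3, h, w], itemsize=4) == 4 * ((3 * h) * w)
```

Since h and w can differ from call to call, the generated function presumably has to obtain their values at run time before it can evaluate this size.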

The fix is to add the exprs to BlackBoxStmt so that the symbols used in those exprs can be visited during codegen.
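
One way to picture why this helps is the toy sketch below. It is not hidet's actual IR: Symbol, Mul, OpaqueStmt, and used_symbols are hypothetical names. It also assumes that symbols are discovered by a visitor that walks the expressions attached to a statement and never parses the statement's text.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple, Union


@dataclass(frozen=True)
class Symbol:
    name: str


@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[int, Symbol, Mul]


@dataclass(frozen=True)
class OpaqueStmt:
    # Hypothetical stand-in for a black-box statement: a text template plus
    # the expressions meant to fill its "{}" placeholders.
    template: str
    exprs: Tuple[Expr, ...] = ()


def collect_symbols(expr: Expr, found: Set[str]) -> None:
    """Recursively record the name of every Symbol inside expr."""
    if isinstance(expr, Symbol):
        found.add(expr.name)
    elif isinstance(expr, Mul):
        collect_symbols(expr.lhs, found)
        collect_symbols(expr.rhs, found)


def used_symbols(stmt: OpaqueStmt) -> List[str]:
    """Symbols visible to a visitor: only those in stmt.exprs, never the text."""
    found: Set[str] = set()
    for expr in stmt.exprs:
        collect_symbols(expr, found)
    return sorted(found)


h, w = Symbol("h"), Symbol("w")
size = Mul(4, Mul(Mul(3, h), w))

# Size baked into the text: the visitor has nothing to walk.
baked = OpaqueStmt("cudaMemcpyAsync(y, x, (4 * ((3 * h) * w)), ...);")
# Size attached as an expression: the visitor reaches h and w.
attached = OpaqueStmt("cudaMemcpyAsync(y, x, {}, ...);", (size,))

print(used_symbols(baked))     # []
print(used_symbols(attached))  # ['h', 'w']
```

In this sketch, a code generator that asks used_symbols which values to define before emitting the statement would miss h and w in the first case and find them in the second.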

yaoyaoding commented 5 months ago

Thanks @KTong821 !