pku-liang / FlexTensor

Automatic Schedule Exploration and Optimization Framework for Tensor Computations
MIT License

problems in DQN search #16

Closed · Maximilianxu closed this 4 years ago

Maximilianxu commented 4 years ago

I ran the optimize/optimize_conv2d.py with the following command:

python3 optimize_conv2d.py --shapes yolo --target cuda --trials 1000 --timeout 10 --parallel 8 --log tmp_log.txt --method q

However, after a lot of warnings during the warm-up phase, I got this error:

Traceback (most recent call last):

  File "optimize_conv2d.py", line 204, in <module>
    logfile=flog,

  File "optimize_conv2d.py", line 116, in optimize
    rpc_info=rpc_info,

  File "/home/max/workspaces/python/FlexTensor/flextensor/scheduler.py", line 2100, in schedule
    perf_path=perf_path,

  File "/home/max/workspaces/python/FlexTensor/flextensor/scheduler.py", line 670, in schedule
    return self._q_schedule(configs, wanted_types, use_model=use_model)

  File "/home/max/workspaces/python/FlexTensor/flextensor/scheduler.py", line 443, in _q_schedule
    from_lst, next_points, action_lst = self.walker_group.walk(cur_lst, trial)

  File "/home/max/workspaces/python/FlexTensor/flextensor/model.py", line 359, in walk
    next_index_lst, direction_lst = self.walkers[name].walk(flattened_lst, index_lst, trial, epsilon, gamma)

  File "/home/max/workspaces/python/FlexTensor/flextensor/model.py", line 72, in walk
    q_values_lst = self.pre_judger(torch.FloatTensor(inputs)).detach()

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)

  File "/home/max/workspaces/python/FlexTensor/flextensor/model.py", line 34, in forward
    out = self.net(inputs)

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)

  File "/home/max/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [1 x 0], m2: [22 x 64] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:197

The same issue arose with another op I wrote myself using the "method=q" setting.

Are there any instructions on using DQN search?

KnowingNothing commented 4 years ago

Can you show the warning messages from the warm-up process? The reason may be that no valid schedule was found during warm-up.

Maximilianxu commented 4 years ago

> Can you show the warning messages from the warm-up process? The reason may be that no valid schedule was found during warm-up.

Warning messages:

warm up [1592718303.555670] [ inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf ]
warm up [1592718315.373111] [ inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf ]
warm up [1592718327.291838] [ inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf ]
Warning: No valid schedule found in warm up process, please use more trials
Now automatically use more trials, increase 20
warm up [1592718339.389995] [ inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf ]
Warning: No valid schedule found in warm up process, please use more trials
Now automatically use more trials, increase 20

After that:

Fail to find valid schedule, too many errors
warm up [1592718736.218251] [ inf inf inf inf ]
Warning: No valid schedule found in warm up process, please use more trials
Now automatically use more trials, increase 4
warm up [1592718738.672842] [ inf inf inf inf ]
Warning: No valid schedule found in warm up process, please use more trials
Now automatically use more trials, increase 4

Code:

import numpy as np
import tvm
from tvm import te
from flextensor.task import Task, register_task

if __name__ == '__main__':
  N = 1
  I, O, H, W = 16, 32, 128, 128
  KH, KW = 5, 5
  STRIDE = 1

  def conv2d(inpf, kernel):
    # NHWI input layout, OIHW kernel layout
    rh, rw, ri = te.reduce_axis((0, KH)), te.reduce_axis((0, KW)), te.reduce_axis((0, I))
    outf = te.compute((N, O, H, W),
      lambda n, o, h, w: te.sum(inpf[n][STRIDE * h + rh][STRIDE * w + rw][ri] * kernel[o][ri][rh][rw],
      [ri, rh, rw]))
    return outf

  def wrap_conv2d():
    inpf = te.placeholder((N, (H - 1) * STRIDE + KH, (W - 1) * STRIDE + KW, I))
    kernel = te.placeholder((O, I, KH, KW))
    outf = conv2d(inpf, kernel)
    return [outf.op], [inpf, kernel, outf]

  task = Task(
      "conv2d",
      "conv2d",
      wrap_conv2d,
      (),
      "cuda",
      0)
  # register the task
  register_task(task)

  from flextensor.scheduler import schedule

  s, bufs, configs = schedule(
              task.key,  # give the key of the target task
              slevel=4,
              rlevel=3,
              op_trial=1000,
              timeout=10,
              op_stop=30,
              method="searching",
              parallel=4,
              )

  # directly use the results
  func = tvm.build(s, bufs, task.target)
  # # use the configs
  # from flextensor.scheduler import schedule_with_config

  # s, bufs = schedule_with_config(task.key, configs)
  # func = tvm.build(s, bufs, task.target)

  np_inpf = np.random.uniform(size=(N, (H - 1) * STRIDE + KH, (W - 1) * STRIDE + KW, I))
  np_knl = np.random.uniform(size=(O, I, KH, KW))

  ctx = tvm.context("cuda", 0)
  inpf, kernel, outf = bufs  # bufs corresponds to [inpf, kernel, outf] from wrap_conv2d
  tvm_inpf = tvm.nd.array(np_inpf.astype(inpf.dtype), ctx)
  tvm_knl = tvm.nd.array(np_knl.astype(kernel.dtype), ctx)
  tvm_outf = tvm.nd.array(np.zeros((N, O, H, W), dtype=outf.dtype), ctx)

  evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
  print('layout NHWI(hw) conv2d: %f ms' % (evaluator(tvm_inpf, tvm_knl, tvm_outf).mean * 1e3))

The TVM version is 0.7dev.

The warning messages repeat many times and eventually report that no valid schedule was found. Could you please provide some advice on this problem? Thanks a lot.

KnowingNothing commented 4 years ago

You can first try optimize_conv2d.py in the flextensor/optimize/ directory with this command:

python3 optimize_conv2d.py --shapes yolo --target cuda --parallel 4

If this example can't run normally, there may be a problem in your environment (such as no CUDA device being available).
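
To quickly rule out a missing CUDA device, a small check along these lines should work (a minimal sketch assuming the tvm.context API of TVM 0.7dev):

import tvm

# Check that TVM can see a CUDA device before running any FlexTensor search.
ctx = tvm.context("cuda", 0)
print("CUDA device found:", ctx.exist)
if ctx.exist:
    print("Device name:", ctx.device_name)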

KnowingNothing commented 4 years ago

I have reproduced your issue. The reason is that FlexTensor uses the "spawn" mode for multiprocessing, so the spawned processes have no access to your customized task. There are two ways to avoid this issue:

  • write and register your task in task.py just like the other pre-defined tasks
  • change "spawn" to "fork" on line 12 of scheduler.py

Using "spawn" as the default in FlexTensor is a temporary compromise; we are still looking for a better solution.

Maximilianxu commented 4 years ago

> I have reproduced your issue. The reason is that FlexTensor uses the "spawn" mode for multiprocessing, so the spawned processes have no access to your customized task. There are two ways to avoid this issue:
>
>   • write and register your task in task.py just like the other pre-defined tasks
>   • change "spawn" to "fork" on line 12 of scheduler.py
>
> Using "spawn" as the default in FlexTensor is a temporary compromise; we are still looking for a better solution.

Thanks a lot, it works now.