apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

TopK returns float32 indices #19753

Open Zha0q1 opened 3 years ago

Zha0q1 commented 3 years ago

It looks like, regardless of the input type, the indices output B[1] is always float32. I checked the C++ code, and I think it is supposed to be int32/int64. Is this behavior by design?

```python
>>> from mxnet import np, npx
>>> A = np.array([1,2,3], dtype='float16')
>>> B = npx.topk(A, k=3, ret_typ='both')
>>> B[0].dtype
dtype('float16')
>>> B[1].dtype
dtype('float32')
```

```cpp
inline bool TopKType(const nnvm::NodeAttrs& attrs,
                     std::vector<int> *in_attrs,
                     std::vector<int> *out_attrs) {
  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
  size_t in_size = in_attrs->size();
  size_t out_size = out_attrs->size();
  CHECK_EQ(in_size, 1);
  CHECK(out_size == 1 || out_size == 2);
  //  out_attr[0] -> stores value
  //  out_attr[1] -> stores indices
  if (out_size > 1) {
    if (param.ret_typ == topk_enum::kReturnValue) {
      // ret_typ='value': the indices output gets a fixed integer type.
#if MXNET_USE_INT64_TENSOR_SIZE == 1
      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
#else
      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
#endif
          << "Failed to set the type of ret_indices.";
    } else {
      // Any other ret_typ (e.g. 'both'): the indices type is taken from the
      // dtype parameter (default float32), which is why B[1] is float32.
      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
          << "Failed to set the type of ret_indices.";
    }
  }
  if (param.ret_typ == topk_enum::kReturnIndices) {
    // ret_typ='indices': output 0 holds the indices, typed by dtype.
    CHECK(type_assign(&(*out_attrs)[0], param.dtype))
        << "Failed to set the type of ret_indices.";
  } else {
    // Otherwise output 0 has the same type as the input.
    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
    return out_attrs->at(0) != -1;
  }
  return true;
}
```
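To make the branches easier to follow, here is a hypothetical pure-Python model of the type-inference logic above. The function name `topk_infer_types` and its arguments are illustrative, not MXNet API: `int64_tensor_size` stands in for the `MXNET_USE_INT64_TENSOR_SIZE` build flag, and the float32 default for `dtype` reflects the documented default of topk's `dtype` argument. The model only mirrors the snippet's branches; it shows that with ret_typ='both' the indices type comes from `dtype`, not from a fixed integer type.

```python
from typing import Optional, Tuple


def topk_infer_types(in_dtype: str, ret_typ: str, dtype: str = "float32",
                     out_size: int = 2, int64_tensor_size: bool = False
                     ) -> Tuple[str, Optional[str]]:
    """Hypothetical mirror of TopKType's branches; returns (out0, out1)."""
    if out_size not in (1, 2):
        raise ValueError("TopKType expects 1 or 2 outputs")
    out1 = None
    if out_size > 1:
        if ret_typ == "value":
            # Indices slot gets a fixed integer type chosen at build time.
            out1 = "int64" if int64_tensor_size else "int32"
        else:
            # Any other ret_typ (e.g. "both"): indices follow the dtype argument.
            out1 = dtype
    if ret_typ == "indices":
        # Output 0 holds the indices, so it takes the dtype argument.
        out0 = dtype
    else:
        # Output 0 keeps the input's dtype.
        out0 = in_dtype
    return out0, out1


# Reproduces the report: float16 input, ret_typ="both", default dtype.
print(topk_infer_types("float16", "both"))  # ('float16', 'float32')
```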

@access2rohit @leezu @sxjscience @szha

sxjscience commented 3 years ago

I think we should set the B[1] to int32/int64.
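One concrete reason integer indices are safer: float32 has a 24-bit significand, so it represents every integer exactly only up to 2**24 (16,777,216). Past that, a float32 index can round to a neighbouring position. A minimal stdlib-only check (the helper `roundtrip_float32` is illustrative):

```python
import struct


def roundtrip_float32(n: int) -> float:
    """Store an integer index as IEEE-754 float32 and read it back."""
    return struct.unpack("<f", struct.pack("<f", float(n)))[0]


# Every integer up to 2**24 survives the round trip exactly...
print(roundtrip_float32(2**24))      # 16777216.0
# ...but 2**24 + 1 cannot be stored and rounds to 2**24.
print(roundtrip_float32(2**24 + 1))  # 16777216.0
```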

Zha0q1 commented 3 years ago

> I think we should set the B[1] to int32/int64.

I just realized there is a dtype parameter in topk that controls the type of the indices output. Its default is float32, which explains the result above.
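A minimal stdlib-only sketch of what such a parameter does, assuming the semantics described above. `topk_both` and `index_cast` are hypothetical names: `index_cast` plays the role of topk's `dtype` argument and defaults to `float` to mirror the float32 default; the sort is descending because topk returns the largest elements by default.

```python
from typing import Callable, List, Sequence, Tuple


def topk_both(values: Sequence[float], k: int,
              index_cast: Callable[[int], object] = float
              ) -> Tuple[List[float], List[object]]:
    """Hypothetical stand-in for topk(..., ret_typ='both'), largest k first."""
    # Positions sorted by value, largest first, truncated to k.
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    # index_cast decides the type of the returned indices.
    return [values[i] for i in order], [index_cast(i) for i in order]


vals, idx = topk_both([1, 2, 3], k=3)
print(vals, idx)  # [3, 2, 1] [2.0, 1.0, 0.0]
vals, idx = topk_both([1, 2, 3], k=3, index_cast=int)
print(idx)  # [2, 1, 0]
```

In MXNet itself, per the ret_typ='both' branch of `TopKType`, passing an integer type such as `npx.topk(A, k=3, ret_typ='both', dtype='int32')` should make B[1] int32.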

I also had a quick discussion with Rohit. We might force the indices output to use int32/int64 in 2.0. In v1.x, we might want to keep the current behavior so that we do not break existing code.
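One hypothetical way to express this plan: keep the v1.x rule behind a flag and, in 2.0, allow only integer index types. The function `resolve_indices_dtype`, the `force_integer` switch, and the choice to reject float requests (rather than silently override them) are assumptions for illustration; the thread does not say how 2.0 would treat an explicit float `dtype`.

```python
from typing import Optional

INTEGER_DTYPES = ("int32", "int64")


def resolve_indices_dtype(requested: Optional[str], force_integer: bool,
                          int64_tensor_size: bool = False) -> str:
    """Hypothetical policy for the type of topk's indices output.

    force_integer=False models keeping the v1.x behavior;
    force_integer=True models the proposed 2.0 behavior.
    int64_tensor_size stands in for MXNET_USE_INT64_TENSOR_SIZE == 1.
    """
    if not force_integer:
        # v1.x: honor the dtype argument, float32 when none is given.
        return requested or "float32"
    if requested is None:
        # 2.0: default to the build's index width.
        return "int64" if int64_tensor_size else "int32"
    if requested in INTEGER_DTYPES:
        # 2.0: an explicit integer request is kept as-is.
        return requested
    # 2.0: float index types are rejected instead of silently replaced.
    raise ValueError(f"indices dtype must be int32 or int64, got {requested!r}")
```

Keeping the switch separate from the dtype logic lets the 2.0 change ship without altering v1.x results.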

sxjscience commented 3 years ago

Yes, +1 for forcing the change in 2.0.