Closed — saudet closed this 3 years ago.
Just checked it: explicitly constructing the `std::vector<const NDArray*>` arguments in these calls should be fine there. There were no other errors.
diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX
index cfd9103430..e9c3bc1042 100644
--- a/libnd4j/include/array/NDArray.hXX
+++ b/libnd4j/include/array/NDArray.hXX
@@ -1218,9 +1218,9 @@ void NDArray::assign(const T& value, bool allowParallelism) {
// just fire scalar
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
- NDArray::prepareSpecialUse({this}, {&temp});
+ NDArray::prepareSpecialUse(std::vector<const NDArray*>{this}, std::vector<const NDArray*>{&temp});
NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.specialShapeInfo(), nullptr, allowParallelism);
- NDArray::registerSpecialUse({this}, {&temp});
+ NDArray::registerSpecialUse(std::vector<const NDArray*>{this}, std::vector<const NDArray*>{&temp});
}
template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism);
template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism);
@quickwritereader Please update my branch!
@quickwritereader If this works for you, please merge. Thanks!
Before we can merge this, libnd4j needs to be updated. Currently with CUDA 11.1, we get compiler errors like this: