naibaf7 / caffe

Caffe: a fast open framework for deep learning. With OpenCL and CUDA support.
http://caffe.berkeleyvision.org/

Cannot compile on OS X 10.11 #13

Closed letalvoj closed 8 years ago

letalvoj commented 8 years ago

Hi,

I tried to compile this code, but it does not seem to work. Is there any commit from which I can start and get it to compile, or is the code not yet usable?

While trying to compile the code, I got mostly easy-to-fix casting errors:

diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp
index 5252fd2..223732e 100644
--- a/src/caffe/layers/base_conv_layer.cpp
+++ b/src/caffe/layers/base_conv_layer.cpp
@@ -26,7 +26,7 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   num_spatial_axes_ = num_axes - first_spatial_axis;
   CHECK_GE(num_spatial_axes_, 0);
   vector<int_tp> bottom_dim_blob_shape(1, num_spatial_axes_ + 1);
-  vector<int_tp> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, 1L));
+  vector<int_tp> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, static_cast<int_tp>(1L)));
   // Setup filter kernel dimensions (kernel_shape_).
   kernel_shape_.Reshape(spatial_dim_blob_shape);
   int_tp* kernel_shape_data = kernel_shape_.mutable_cpu_data();
diff --git a/src/caffe/layers/eltwise_layer.cpp b/src/caffe/layers/eltwise_layer.cpp
index c76f80c..7b53eb7 100644
--- a/src/caffe/layers/eltwise_layer.cpp
+++ b/src/caffe/layers/eltwise_layer.cpp
@@ -66,7 +66,7 @@ void EltwiseLayer<Dtype>::Forward_cpu(
   case EltwiseParameter_EltwiseOp_MAX:
     // Initialize
     mask = max_idx_.mutable_cpu_data();
-    caffe_set(count, -1L, mask);
+    caffe_set(count, -1LL, mask);
     caffe_set(count, Dtype(-FLT_MAX), top_data);
     // bottom 0 & 1
     bottom_data_a = bottom[0]->cpu_data();
diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index 6de258c..5bbf009 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -33,7 +33,7 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GE(num_spatial_axes_, 0);

   vector<int_tp> bottom_dim_blob_shape(1, num_spatial_axes_ + 1);
-  vector<int_tp> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, 1L));
+  vector<int_tp> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, static_cast<int_tp>(1L)));

   kernel_shape_.Reshape(spatial_dim_blob_shape);
   int_tp* kernel_shape_data = kernel_shape_.mutable_cpu_data();
@@ -259,7 +259,7 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       caffe_set(top_count, Dtype(-1), top_mask);
     } else {
       mask = max_idx_.mutable_cpu_data();
-      caffe_set(top_count, -1L, mask);
+      caffe_set(top_count, -1LL, mask);
     }
     caffe_set(top_count, Dtype(-FLT_MAX), top_data);
     // The main loop
@@ -271,8 +271,8 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
             int_tp wstart = pw * stride_w_ - pad_w_;
             int_tp hend = min(hstart + kernel_h_, height_);
             int_tp wend = min(wstart + kernel_w_, width_);
-            hstart = max(hstart, 0L);
-            wstart = max(wstart, 0L);
+            hstart = max(hstart, static_cast<int_tp>(0L));
+            wstart = max(wstart, static_cast<int_tp>(0L));
             const int_tp pool_index = ph * pooled_width_ + pw;
             for (int_tp h = hstart; h < hend; ++h) {
               for (int_tp w = wstart; w < wend; ++w) {
@@ -314,8 +314,8 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
             int_tp hend = min(hstart + kernel_h_, height_ + pad_h_);
             int_tp wend = min(wstart + kernel_w_, width_ + pad_w_);
             int_tp pool_size = (hend - hstart) * (wend - wstart);
-            hstart = max(hstart, 0L);
-            wstart = max(wstart, 0L);
+            hstart = max(hstart, static_cast<int_tp>(0L));
+            wstart = max(wstart, static_cast<int_tp>(0L));
             hend = min(hend, height_);
             wend = min(wend, width_);
             for (int_tp h = hstart; h < hend; ++h) {
@@ -406,8 +406,8 @@ void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
             int_tp hend = min(hstart + kernel_h_, height_ + pad_h_);
             int_tp wend = min(wstart + kernel_w_, width_ + pad_w_);
             int_tp pool_size = (hend - hstart) * (wend - wstart);
-            hstart = max(hstart, 0L);
-            wstart = max(wstart, 0L);
+            hstart = max(hstart, static_cast<int_tp>(0L));
+            wstart = max(wstart, static_cast<int_tp>(0L));
             hend = min(hend, height_);
             wend = min(wend, width_);
             for (int_tp h = hstart; h < hend; ++h) {
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp
index d5a777b..613677e 100644
--- a/src/caffe/layers/window_data_layer.cpp
+++ b/src/caffe/layers/window_data_layer.cpp
@@ -332,10 +332,10 @@ void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
         // the extent beyond the image
         int_tp unclipped_height = y2-y1+1;
         int_tp unclipped_width = x2-x1+1;
-        int_tp pad_x1 = std::max(0L, -x1);
-        int_tp pad_y1 = std::max(0L, -y1);
-        int_tp pad_x2 = std::max(0L, x2 - cv_img.cols + 1);
-        int_tp pad_y2 = std::max(0L, y2 - cv_img.rows + 1);
+        int_tp pad_x1 = std::max(static_cast<int_tp>(0L), -x1);
+        int_tp pad_y1 = std::max(static_cast<int_tp>(0L), -y1);
+        int_tp pad_x2 = std::max(static_cast<int_tp>(0L), x2 - cv_img.cols + 1);
+        int_tp pad_y2 = std::max(static_cast<int_tp>(0L), y2 - cv_img.rows + 1);
         // clip bounds
         x1 = x1 + pad_x1;
         x2 = x2 - pad_x2;
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
index ab56474..ce1b5d7 100644
--- a/src/caffe/util/hdf5.cpp
+++ b/src/caffe/util/hdf5.cpp
@@ -125,7 +125,7 @@ void hdf5_save_nd_dataset<double>(

 string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
   // Get size of dataset
-  uint_tp size;
+  size_t size;
   H5T_class_t class_;
   herr_t status = \
     H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
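
A plausible root cause for these, assuming int_tp is a 64-bit typedef along the lines of int64_t (an assumption, not the branch's actual definition): glibc defines int64_t as long, while OS X defines it as long long, so a literal like 1L matches int_tp on Linux but not on OS X, where std::max then cannot deduce a single template parameter. A minimal standalone sketch:

#include <algorithm>

// Hypothetical stand-in for this branch's 64-bit index type on OS X:
typedef long long int_tp;

int main() {
  int_tp num_spatial_axes = 2;
  // Fails on OS X: std::max(const T&, const T&) cannot deduce T when one
  // argument is long long and the other is long (the 1L literal):
  // int_tp bad = std::max(num_spatial_axes, 1L);
  // The cast used in the patch above makes both arguments the same type:
  int_tp ok = std::max(num_spatial_axes, static_cast<int_tp>(1L));
  return static_cast<int>(ok);
}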

But then I got this one:

src/caffe/util/math_functions.cpp:146:3: error: call to 'cblas_saxpby' is ambiguous
  cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
  ^~~~~~~~~~~~
/usr/local/include/cblas.h:324:6: note: candidate function
void cblas_saxpby(OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST float *x, OPENBLAS_CONST blasint incx,OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);
     ^
./include/caffe/util/mkl_alternate.hpp:83:13: note: candidate function
inline void cblas_saxpby(const int_tp N, const float alpha, const float* X,
            ^
src/caffe/util/math_functions.cpp:152:3: error: call to 'cblas_daxpby' is ambiguous
  cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
  ^~~~~~~~~~~~
/usr/local/include/cblas.h:326:6: note: candidate function
void cblas_daxpby(OPENBLAS_CONST blasint n, OPENBLAS_CONST double alpha, OPENBLAS_CONST double *x, OPENBLAS_CONST blasint incx,OPENBLAS_CONST double beta, double *y, OPENBLAS_CONST blasint incy);
     ^
./include/caffe/util/mkl_alternate.hpp:89:13: note: candidate function
inline void cblas_daxpby(const int_tp N, const double alpha, const double* X,
            ^
2 errors generated.
make: *** [.build_release/src/caffe/util/math_functions.o] Error 1
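
As far as I can tell, the ambiguity is mechanical: N is a 64-bit int_tp while the stride arguments are plain int literals, so the blasint overload from cblas.h matches the literals better and the int_tp overload from mkl_alternate.hpp matches N better, and neither wins. A standalone sketch with hypothetical names (only the integer parameter types matter):

typedef int blasint;       // OpenBLAS's 32-bit index type
typedef long long int_tp;  // this branch's 64-bit index type

// Stand-ins for the two declarations quoted above:
void axpby(blasint, float, const float*, blasint, float, float*, blasint) {}
void axpby(int_tp, float, const float*, int_tp, float, float*, int_tp) {}

int main() {
  int_tp n = 4;
  float x[4] = {0}, y[4] = {0};
  // error: call to 'axpby' is ambiguous -- the blasint overload matches the
  // int literals (1) better, the int_tp overload matches n better:
  // axpby(n, 1.0f, x, 1, 0.0f, y, 1);
  // Casting n (or all integer arguments) to one index type resolves it:
  axpby(static_cast<blasint>(n), 1.0f, x, 1, 0.0f, y, 1);
  return 0;
}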
letalvoj commented 8 years ago

Never mind, I found this thread, which explains everything. :)

naibaf7 commented 8 years ago

So did you get it to work now? What OS, backend, and libraries are you using? From what the problem looks like, it's OS X? Probably an incompatible libstdc++/clang++.

letalvoj commented 8 years ago

No, I did not, but I thought that, according to the thread, the build was broken.

I have a MacBook Pro (late 2014) with an Intel Iris Pro GPU. I do not use anything unusual.

I am able to compile the original Caffe without any problem after setting the proper paths to Python, but as you can see, I then get weird semantic errors on this branch.

If I try to compile the latest master as it is, then I get this error.

edit: libstdc++?? Let's see…

letalvoj commented 8 years ago

libstdc++ is only set if the CUDA version is lower than 7. Forcing it manually just causes another error.

I do not think that it matters, but my Makefile.config is: http://pastebin.com/99Zk7m7b

naibaf7 commented 8 years ago

Well, I recently switched my branch to use 64-bit indexing for everything. That might cause issues on OS X, which I personally do not have and cannot test. It would be nice if we could resolve that somehow; meanwhile, try using a commit that is a bit older. For example, 80d045263f26c41a1e886906a30d649a5c812038 could work for you.

letalvoj commented 8 years ago

Oh, thanks! :) Btw, some of the casting errors were 'could not cast long long to long'. If that is the problem, aren't the longs long enough? They are already 64 bits…
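
Checking locally, both types are 8 bytes, so it must be type identity rather than width. A quick standalone check (assuming an LP64 platform like OS X):

#include <cstdio>
#include <type_traits>

int main() {
  // long and long long are distinct types in C++ even when both are
  // 64 bits wide, so conversions between them are still diagnosed:
  std::printf("sizeof(long)=%zu sizeof(long long)=%zu same_type=%d\n",
              sizeof(long), sizeof(long long),
              static_cast<int>(std::is_same<long, long long>::value));
  // On OS X this prints: sizeof(long)=8 sizeof(long long)=8 same_type=0
  return 0;
}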


letalvoj commented 8 years ago

Ok, I give up.

I've already spent too much time trying to compile this stuff. After compiling the version you proposed, it crashes at runtime. It seems to be the same behaviour as when I try to use a GPU with CPU-only compiled binaries. Also, the older Makefile does not allow disabling compilation of the CUDA backend, so it takes a looooooong time to compile.

naibaf7 commented 8 years ago

@letalvoj Ok, the older Makefile from the commit I noted above does actually allow disabling the backends in the same way. As I said before, since I do not have a computer with OS X, I can't really support it. My test systems use Ubuntu 14.04 and Fedora 22. If you still want to get it working anyway, the only thing I can suggest is contacting one of the people here who got it working on OS X, for example @lavania @paulbroyles @JacksonGariety

cepiross commented 8 years ago

@naibaf7 Hi, I resolved the OS X support issue in https://github.com/BVLC/caffe/pull/3720. Would you like me to make a pull request against your repository? Please let me know :)

Thanks.

naibaf7 commented 8 years ago

@cepiross I'll regression-test this and merge it into both branches, thanks. No need for another pull request.

naibaf7 commented 8 years ago

Fix is merged.