clinicalml / cfrnet

Counterfactual Regression
MIT License
295 stars, 82 forks

Error when using wass #4

Closed: siyuanzhao closed this issue 1 year ago

siyuanzhao commented 6 years ago

Thanks for sharing your code and excellent work!

I am trying to run your code on other experimental data, but I get an error when imb_fun is set to wass. The code works with the other imb_fun options, such as mmd2_rbf.

The error log is below. Do you have any idea what might cause this error, so that I can fix it?

```
Traceback (most recent call last):
  File "cfr_net_train.py", line 427, in main
    run(outdir)
  File "cfr_net_train.py", line 374, in run
    D_exp_test, logfile, i_exp)
  File "cfr_net_train.py", line 142, in train
    CFR.r_alpha: FLAGS.p_alpha, CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated})
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 964, in _run
    feed_dict_string, options, run_metadata)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
    target_list, options, run_metadata)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
    raise type(e)(node_def, op, message)
InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
     [[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
     [[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Caused by op u'concat_2', defined at:
  File "cfr_net_train.py", line 434, in <module>
    tf.app.run()
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "cfr_net_train.py", line 427, in main
    run(outdir)
  File "cfr_net_train.py", line 283, in run
    CFR = cfr.cfrnet(x, t, y, p, FLAGS, r_alpha, r_lambda, do_in, do_out, dims)
  File "/home/siyuan/git/cfrnet/cfr/cfr_net.py", line 25, in __init__
    self._build_graph(x, t, y , p_t, FLAGS, r_alpha, r_lambda, do_in, do_out, dims)
  File "/home/siyuan/git/cfrnet/cfr/cfr_net.py", line 185, in _build_graph
    imb_dist, imb_mat = wasserstein(h_rep_norm,t,p_ipm,lam=FLAGS.wass_lambda,its=FLAGS.wass_iterations,sq=False,backpropT=FLAGS.wass_bpt)
  File "/home/siyuan/git/cfrnet/cfr/util.py", line 230, in wasserstein
    Mt = tf.concat(1,[Mt,col])
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1080, in concat
    name=name)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 438, in _concat
    values=values, name=name)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
    op_def=op_def)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2240, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1128, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
     [[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
     [[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
```
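The two shapes in the error can be traced with a small NumPy sketch of the padding step that wasserstein() performs around util.py line 230 (append one row, then one column, to the distance matrix M). It assumes, as the tf.where(t>0) and tf.where(t<1) terms in the patched lines quoted in this thread suggest, that M has one row per treated unit and one column per control unit; pad_cost_matrix is a hypothetical helper name and delta is just a constant here. Under that assumption, a minibatch with no treated units gives M the shape (0, 10), which reproduces the [0,10] vs. [1,1] mismatch in the log.

```python
import numpy as np

def pad_cost_matrix(M, delta):
    """Hypothetical NumPy mirror of the padding in wasserstein():
    append a row of delta, then a column of delta ending in 0."""
    # (1, n_c) when M has rows, but (0, n_c) when M has no rows
    row = delta * np.ones(M[0:1, :].shape)
    # (n_t + 1, 1): one delta per row of M, plus a 0 corner entry
    col = np.concatenate([delta * np.ones(M[:, 0:1].shape),
                          np.zeros((1, 1))], axis=0)
    # (n_t + 1, n_c) when n_t > 0; stays (0, n_c) when n_t == 0
    Mt = np.concatenate([M, row], axis=0)
    # (n_t + 1, n_c + 1); fails if Mt and col disagree on the row count
    return np.concatenate([Mt, col], axis=1)

# Normal batch: 3 treated and 10 control units.
print(pad_cost_matrix(np.ones((3, 10)), delta=2.0).shape)  # (4, 11)

# Batch with no treated units: Mt is (0, 10) but col is (1, 1).
try:
    pad_cost_matrix(np.ones((0, 10)), delta=2.0)
except ValueError as err:
    print("concat failed:", err)
```

Under this reading, the failure depends on batch composition: a batch with at least one treated and one control unit yields the expected (n_t + 1, n_c + 1) matrix.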

nlokeshiisc commented 2 years ago

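The reported issue might be caused by a TensorFlow version mismatch. I used TensorFlow 1.14, where the original calls fail because tf.concat changed its argument order in TensorFlow 1.0: older releases use tf.concat(concat_dim, values), while 1.0 and later use tf.concat(values, axis). I fixed the code by making the following changes in the wasserstein() function: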

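```python
# TF >= 1.0 order: tf.concat(values, axis) instead of tf.concat(concat_dim, values).
# `row` is defined earlier in wasserstein() and is not part of this change.
col = tf.concat([delta*tf.ones(tf.shape(M[:,0:1])),tf.zeros((1,1))], axis=0)
Mt = tf.concat([M,row], axis=0)   # append the extra row
Mt = tf.concat([Mt,col], axis=1)  # append the extra column
# Weights: p/nt per treated unit plus (1-p); (1-p)/nc per control unit plus p
a = tf.concat([p*tf.ones(tf.shape(tf.where(t>0)[:,0:1]))/nt, (1-p)*tf.ones((1,1))], axis=0)
b = tf.concat([(1-p)*tf.ones(tf.shape(tf.where(t<1)[:,0:1]))/nc, p*tf.ones((1,1))], axis=0)
```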
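For intuition about why Mt, a and b must line up, the patched construction can be mirrored in NumPy and fed into a generic entropic-regularized Sinkhorn loop. This is a sketch, not the repository's implementation: the helper names pad_and_weights and sinkhorn_distance, the Euclidean distances, padding with delta = max(M), the empty-group guard, and the defaults for lam and its (named after the keyword arguments in the wasserstein() call in the traceback) are all assumptions. With these weights, a sums to n_t * p/n_t + (1 - p) = 1 and b sums to n_c * (1 - p)/n_c + p = 1, so both remain probability vectors of lengths n_t + 1 and n_c + 1.

```python
import numpy as np

def pad_and_weights(Xt, Xc, p, delta=None):
    """Hypothetical NumPy version of the patched lines: padded cost
    matrix Mt plus weight vectors a (treated side) and b (control side)."""
    nt, nc = len(Xt), len(Xc)
    if nt == 0 or nc == 0:
        # Assumption: reject batches with an empty group instead of failing in concat
        raise ValueError("need at least one treated and one control unit")
    # Pairwise Euclidean distances, shape (nt, nc)
    M = np.sqrt(((Xt[:, None, :] - Xc[None, :, :]) ** 2).sum(axis=-1))
    if delta is None:
        delta = M.max()  # assumption: pad with the largest pairwise distance
    row = delta * np.ones((1, nc))
    col = np.concatenate([delta * np.ones((nt, 1)), np.zeros((1, 1))], axis=0)
    Mt = np.concatenate([np.concatenate([M, row], axis=0), col], axis=1)
    a = np.concatenate([p * np.ones(nt) / nt, [1 - p]])
    b = np.concatenate([(1 - p) * np.ones(nc) / nc, [p]])
    return Mt, a, b

def sinkhorn_distance(Mt, a, b, lam=10.0, its=50):
    """Generic entropic Sinkhorn: scale K = exp(-lam * Mt) to a plan
    T = diag(u) K diag(v) with row sums a and column sums close to b."""
    K = np.exp(-lam * Mt)
    u = np.ones_like(a)
    for _ in range(its):
        v = b / (K.T @ u)  # match column sums to b
        u = a / (K @ v)    # match row sums to a
    T = u[:, None] * K * v[None, :]
    return float((T * Mt).sum()), T

rng = np.random.default_rng(0)
Mt, a, b = pad_and_weights(rng.normal(size=(4, 3)), rng.normal(size=(6, 3)), p=0.4)
dist, T = sinkhorn_distance(Mt, a, b, lam=1.0, its=200)
print(Mt.shape)  # (5, 7)
```

Because a has n_t + 1 entries and b has n_c + 1, the scaling steps only work if Mt is exactly (n_t + 1) x (n_c + 1), which is the invariant the two concat calls are meant to maintain.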