siyuanzhao closed this issue 1 year ago
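The reported issue might be caused by a TensorFlow version mismatch. I used TensorFlow 1.14, where the older tf.concat(concat_dim, values) call order used in util.py no longer seems to be accepted.
I fixed the code by making the following changes in the wasserstein() function, which pass the list of tensors first and the axis as a keyword: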
col = tf.concat([delta * tf.ones(tf.shape(M[:, 0:1])), tf.zeros((1, 1))], axis=0)
Mt = tf.concat([M, row], axis=0)
Mt = tf.concat([Mt, col], axis=1)
a = tf.concat([p * tf.ones(tf.shape(tf.where(t > 0)[:, 0:1])) / nt, (1 - p) * tf.ones((1, 1))], axis=0)
b = tf.concat([(1 - p) * tf.ones(tf.shape(tf.where(t < 1)[:, 0:1])) / nc, p * tf.ones((1, 1))], axis=0)
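The same argument swap can be applied mechanically to any remaining old-style calls. Below is a minimal sketch that uses only the Python standard library (ast.unparse needs Python 3.9 or newer). The names ConcatArgSwap and upgrade_concat are hypothetical helpers made up for this illustration, not part of TensorFlow or cfrnet. It only handles calls whose first argument is a literal integer, and the rewritten source loses its original spacing and comments.

```python
import ast


class ConcatArgSwap(ast.NodeTransformer):
    """Rewrite tf.concat(dim, values) as tf.concat(values, axis=dim)."""

    def visit_Call(self, node):
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        is_tf_concat = (
            isinstance(func, ast.Attribute)
            and func.attr == "concat"
            and isinstance(func.value, ast.Name)
            and func.value.id == "tf"
        )
        # Old-style calls pass a literal integer axis as the first argument.
        old_style = (
            len(node.args) == 2
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, int)
        )
        if is_tf_concat and old_style:
            dim, values = node.args
            node.args = [values]
            node.keywords.append(ast.keyword(arg="axis", value=dim))
        return node


def upgrade_concat(source):
    """Return `source` with old-style tf.concat calls rewritten."""
    tree = ConcatArgSwap().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


# The call from util.py, line 230, as it appears in the traceback.
print(upgrade_concat("Mt = tf.concat(1,[Mt,col])"))
```

Calls that already use the (values, axis) order are left unchanged, because their first argument is a list rather than an integer.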
Thanks for sharing your code and excellent work!
I tried to run your code on other experimental data, but I ran into an error when using wass. The code works when other imb_funs (mmd2_rbf) are used.
The error log is below. Do you have any idea what might cause this error, so that I can fix it?
Traceback (most recent call last):
File "cfr_net_train.py", line 427, in main
run(outdir)
File "cfr_net_train.py", line 374, in run
D_exp_test, logfile, i_exp)
File "cfr_net_train.py", line 142, in train
CFR.r_alpha: FLAGS.p_alpha, CFR.r_lambda: FLAGS.p_lambda, CFR.p_t: p_treated})
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
run_metadata_ptr)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 964, in _run
feed_dict_string, options, run_metadata)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
target_list, options, run_metadata)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
raise type(e)(node_def, op, message)
InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
[[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
[[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
Caused by op u'concat_2', defined at:
File "cfr_net_train.py", line 434, in <module>
tf.app.run()
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 43, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "cfr_net_train.py", line 427, in main
run(outdir)
File "cfr_net_train.py", line 283, in run
CFR = cfr.cfrnet(x, t, y, p, FLAGS, r_alpha, r_lambda, do_in, do_out, dims)
File "/home/siyuan/git/cfrnet/cfr/cfr_net.py", line 25, in __init__
self._build_graph(x, t, y , p_t, FLAGS, r_alpha, r_lambda, do_in, do_out, dims)
File "/home/siyuan/git/cfrnet/cfr/cfr_net.py", line 185, in _build_graph
imb_dist, imb_mat = wasserstein(h_rep_norm,t,p_ipm,lam=FLAGS.wass_lambda,its=FLAGS.wass_iterations,sq=False,backpropT=FLAGS.wass_bpt)
File "/home/siyuan/git/cfrnet/cfr/util.py", line 230, in wasserstein
Mt = tf.concat(1,[Mt,col])
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1080, in concat
name=name)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 438, in _concat
values=values, name=name)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
op_def=op_def)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2240, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/home/siyuan/.virtualenvs/tf012/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1128, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [0,10] vs. shape[1] = [1,1]
[[Node: concat_2 = Concat[N=2, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](concat_2/concat_dim, concat_1, concat)]]
[[Node: gradients/Cast_2/_111 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_588_gradients/Cast_2", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
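One possible reading of the shapes in this log, offered as an inference rather than something confirmed in the thread: the failing concat joins a [0,10] tensor with a [1,1] tensor, which would happen if the cost matrix M had zero rows, meaning the fed batch contained no units with t > 0. The sketch below traces only the shapes through the padding steps of wasserstein() in plain Python, with no TensorFlow needed. It assumes M has shape [nt, nc] and that row takes the shape of M[0:1, :]; neither definition appears in the thread. concat_shape and padded_cost_shape are hypothetical helper names.

```python
def concat_shape(shapes, axis):
    """Shape of concatenating 2-D tensors along `axis`; raises on a mismatch."""
    other = 1 - axis
    if len({s[other] for s in shapes}) != 1:
        raise ValueError(
            "Dimensions of inputs should match: "
            + " vs. ".join(str(list(s)) for s in shapes)
        )
    out = [0, 0]
    out[other] = shapes[0][other]
    out[axis] = sum(s[axis] for s in shapes)
    return tuple(out)


def padded_cost_shape(nt, nc):
    """Trace shapes through the padding steps for nt treated, nc control units."""
    M = (nt, nc)                       # assumed shape of the pairwise cost matrix
    row = (min(nt, 1), nc)             # shape of M[0:1, :] (assumed definition of row)
    col = concat_shape([(nt, min(nc, 1)), (1, 1)], axis=0)  # M[:, 0:1] plus a 1x1 zero
    Mt = concat_shape([M, row], axis=0)
    return concat_shape([Mt, col], axis=1)


# A batch with both groups present pads cleanly.
print(padded_cost_shape(3, 10))

# A batch with no treated units fails at the last concat, as in the log.
try:
    padded_cost_shape(0, 10)
except ValueError as err:
    print("no treated units:", err)
```

If this reading is right, it would be worth checking whether the treatment indicator t in the new data is coded as 0/1 and whether every batch contains both treated and control units.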