google-research / deep_representation_one_class

Apache License 2.0

TypeError: Expected int32 passed to parameter 'shape' of op 'ScatterNd', got [3, None, 64, 64, 3] #4

Leggin closed this issue 3 years ago.

Leggin commented 3 years ago

Hi, I am having a problem replicating the CelebA results. I am running run_contrastive.sh with these settings:

DATA=celeba
METHOD=Contrastive
SEED=1
CATEGORY=Eyeglasses
MODEL_DIR='.'
python train_and_eval_loop.py \
  --model_dir="${MODEL_DIR}" \
  --method=${METHOD} \
  --file_path="${DATA}_${PREFIX}_s${SEED}_c${CATEGORY}" \
  --dataset=${DATA} \
  --category=${CATEGORY} \
  --seed=${SEED} \
  --root='' \
  --net_type=ResNet18 \
  --net_width=1 \
  --latent_dim=0 \
  --aug_list="cnr0.5+hflip+jitter_b0.4_c0.4_s0.4_h0.4+gray0.2+blur0.5,+" \
  --aug_list_for_test="x" \
  --input_shape="64,64,3" \
  --optim_type=sgd \
  --sched_type=cos \
  --learning_rate=0.01 \
  --momentum=0.9 \
  --weight_decay=0.0003 \
  --head_dims="512,512,512,512,512,512,512,512,128" \
  --num_epoch=2048 \
  --batch_size=32 \
  --temperature=0.2 \
  --distaug_type 1

The error I get is:

    /test/contrastive.py:87 step_fn  *
        y = self.get_target_labels(
    /test/contrastive.py:69 get_target_labels  *
        x_concat = self.cross_replica_concat(x, replica_context=replica_context)
    /test/util/train.py:792 cross_replica_concat  *
        ext_tensor = tf.scatter_nd(
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py:8856 scatter_nd  **
        "ScatterNd", indices=indices, updates=updates, shape=shape, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:479 _apply_op_helper
        repr(values), type(values).__name__, err))

    TypeError: Expected int32 passed to parameter 'shape' of op 'ScatterNd', got [3, None, 64, 64, 3] of type 'list' instead. Error: Expected int32, got None of type 'NoneType' instead.

Can you help me fix this?
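The last line of the traceback points at the cause: `tf.scatter_nd` received its `shape` argument as a Python list, `[3, None, 64, 64, 3]`, and the `None` (the batch dimension, which is unknown when the training step is traced as a graph) cannot be converted to `int32`. The leading 3 is presumably the number of replicas. The code in `util/train.py` is not shown in this thread, so the following is only a minimal sketch, assuming the shape there is built from the tensor's static shape. It shows the general workaround of building the shape from `tf.shape(x)`, which is an `int32` tensor even when the batch size is unknown at trace time. The helper `scatter_to_replica_slot` and its `replica_id`/`num_replicas` parameters are hypothetical, and the all-reduce step that a real cross-replica concat performs afterwards is omitted.

```python
import tensorflow as tf


def scatter_to_replica_slot(x, replica_id, num_replicas):
    """Hypothetical helper: place `x` in slot `replica_id` of a zero tensor
    that has an extra leading replica dimension of size `num_replicas`."""
    # tf.shape(x) is a dynamic int32 tensor, so this works even when the
    # static batch dimension is None. A Python list such as
    # [num_replicas] + x.shape.as_list() would contain None instead and
    # raise the TypeError seen in the traceback.
    shape = tf.concat([[num_replicas], tf.shape(x)], axis=0)
    return tf.scatter_nd(indices=[[replica_id]], updates=[x], shape=shape)


# Trace with an unknown batch size, as happens inside a training step.
@tf.function(input_signature=[tf.TensorSpec([None, 64, 64, 3], tf.float32)])
def demo(x):
    # Assumed values for illustration only: 3 replicas, this is replica 1.
    return scatter_to_replica_slot(x, replica_id=1, num_replicas=3)


out = demo(tf.ones([2, 64, 64, 3]))
print(out.shape.as_list())  # [3, 2, 64, 64, 3]
```

Slot 1 holds the input and the other slots are zeros; summing this tensor across replicas (for example with `all_reduce`) would then fill in every slot. Whether this matches the repository's intended fix is not confirmed in this thread.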

Leggin commented 3 years ago

The fix is to upgrade tensorflow-gpu to a newer version; I used 2.4.0.
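If you take the upgrade route, a guard like the following can fail fast on an older TensorFlow build instead of failing deep inside the training step. It is only a sketch: the 2.4.0 threshold comes solely from this comment, the next reply reports the same error on 2.4.0 (so the version may not be the whole story), and `parse_version` and `require_tf_at_least` are hypothetical helpers, not part of the repository.

```python
import re

import tensorflow as tf


def parse_version(version):
    """Hypothetical helper: turn '2.4.0' or '2.4.0-rc1' into (2, 4, 0)."""
    nums = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits, e.g. '0-rc1' -> 0, '0rc1' -> 0.
        match = re.match(r"\d+", piece)
        nums.append(int(match.group()) if match else 0)
    # Pad short versions such as '2.4' to three components.
    nums += [0] * (3 - len(nums))
    return tuple(nums)


def require_tf_at_least(minimum="2.4.0"):
    """Raise early if the installed TensorFlow is older than `minimum`."""
    if parse_version(tf.__version__) < parse_version(minimum):
        raise RuntimeError(
            f"TensorFlow {tf.__version__} found; this setup was reported "
            f"to work with {minimum} or newer."
        )
```

Calling `require_tf_at_least()` near the top of an entry point such as `train_and_eval_loop.py` would surface a version mismatch before any data is loaded.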

rkakash59 commented 2 years ago

I am getting the same error with tensorflow-gpu==2.4.0.