jonbarron / camp_zipnerf

Apache License 2.0

flax.errors.ScopeParamShapeError when executing 360_eval.sh or 360_render.sh after executing camp/360_train.sh #23

Closed ppponpon closed 4 months ago

ppponpon commented 5 months ago

Hi, thanks for this great paper and code. After running camp/360_train.sh, I tried to run 360_eval.sh and 360_render.sh, but a flax.errors.ScopeParamShapeError always occurs. Is there a solution? Below is the console log from running 360_render.sh.

` 2024-04-25 05:47:31.916448: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-04-25 05:47:31.916484: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-04-25 05:47:31.917323: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2024-04-25 05:47:32.577797: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT I0425 05:47:41.449953 140104660227264 render.py:144] Rendering config:

Parameters for eval/Config:

==============================================================================

eval/Config.adam_beta1 = 0.9 eval/Config.adam_beta2 = 0.99 eval/Config.adam_eps = 1e-15 eval/Config.arcore_format_pose_file = None eval/Config.autoexpose_renders = False eval/Config.batch_size = 8192 eval/Config.batching = 'all_images' eval/Config.camera_perturb_dolly_use_average = False eval/Config.camera_perturb_intrinsic_single = True eval/Config.camera_perturb_sigma_dolly_z = 0.0 eval/Config.camera_perturb_sigma_focal_length = 0.0 eval/Config.camera_perturb_sigma_look_at = 0.0 eval/Config.camera_perturb_sigma_position = 0.0 eval/Config.camera_perturb_zero_distortion = False eval/Config.cast_rays_in_eval_step = True eval/Config.cast_rays_in_train_step = True eval/Config.charb_padding = 0.001 eval/Config.checkpoint_dir = \ '/home/user/camp_zipnerf_output/zipnerf/360/garden' eval/Config.checkpoint_every = 10000 eval/Config.checkpoint_init = False eval/Config.checkpoint_keep = 2 eval/Config.colmap_subdir = None eval/Config.compute_disp_metrics = False eval/Config.compute_normal_metrics = False eval/Config.compute_procrustes_metric = False eval/Config.data_coarse_loss_mult = 0.0 eval/Config.data_dir = '/home/user/data/360_v2/garden' eval/Config.data_loss_mult = 1.0 eval/Config.data_loss_type = 'charb' eval/Config.dataset_loader = 'llff' eval/Config.debug_mode = False eval/Config.deterministic_showcase = True eval/Config.disable_multiscale_loss = False eval/Config.disable_pmap_and_jit = False eval/Config.distortion_loss_curve_fn = \ (@math.power_ladder, {'p': -0.25, 'premult': 10000.0}) eval/Config.distortion_loss_mult = 0.01 eval/Config.distortion_loss_target = 'tdist' eval/Config.donate_args_to_train = True eval/Config.dtu_light_cond = 3 eval/Config.early_exit_steps = None eval/Config.eikonal_coarse_loss_mult = 0.0 eval/Config.eikonal_loss_mult = 0.0 eval/Config.enable_grid_c2f = False eval/Config.enable_loss_scaler = False eval/Config.eval_crop_borders = 0 eval/Config.eval_dataset_limit = 2147483647 eval/Config.eval_only_once = True 
eval/Config.eval_quantize_metrics = True eval/Config.eval_raw_affine_cc = False eval/Config.eval_render_interval = 1 eval/Config.eval_save_output = True eval/Config.eval_save_ray_data = False eval/Config.exposure_percentile = 97.0 eval/Config.factor = 4 eval/Config.far = 1000000.0 eval/Config.far_plane_meters = None eval/Config.focal_length_var_loss_mult = 0.0 eval/Config.forward_facing = False eval/Config.gc_every = 10000 eval/Config.grad_max_norm = 0.0 eval/Config.grad_max_val = 0.0 eval/Config.grid_c2f_weight_method = 'cosine_sequential' eval/Config.image_subdir = None eval/Config.jax_rng_seed = 20200823 eval/Config.llff_load_from_poses_bounds = False eval/Config.llff_use_all_images_for_training = False eval/Config.llffhold = 8 eval/Config.load_alphabetical = True eval/Config.load_colmap_points = False eval/Config.load_ngp_format_poses = False eval/Config.lock_up = False eval/Config.loss_scale = 1000.0 eval/Config.lr_delay_mult = 1e-08 eval/Config.lr_delay_steps = 20000 eval/Config.lr_final = 0.000125 eval/Config.lr_final_grid = None eval/Config.lr_init = 0.00125 eval/Config.lr_init_grid = None eval/Config.max_steps = 200000 eval/Config.multiscale_train_factors = None eval/Config.near = 0.0 eval/Config.near_plane_meters = None eval/Config.np_rng_seed = 20201473 eval/Config.num_border_pixels_to_mask = 0 eval/Config.num_showcase_images = 5 eval/Config.optimize_cameras = False eval/Config.optimize_test_cameras = False eval/Config.optimize_test_cameras_batch_size = 10000 eval/Config.optimize_test_cameras_for_n_steps = 200 eval/Config.optimize_test_cameras_lr = 0.001 eval/Config.orientation_coarse_loss_mult = 0.0 eval/Config.orientation_loss_mult = 0.0 eval/Config.orientation_loss_target = 'normals_pred' eval/Config.param_regularizers = \ {'grid_0': (0.1, @jnp.mean, 2, 1), 'grid_1': (0.1, @jnp.mean, 2, 1), 'grid_2': (0.1, @jnp.mean, 2, 1)} eval/Config.patch_size = 1 eval/Config.predicted_normal_coarse_loss_mult = 0.0 eval/Config.predicted_normal_loss_mult = 0.0 
eval/Config.principal_point_reg_loss_mult = 0.0 eval/Config.principal_point_var_loss_mult = 0.0 eval/Config.print_camera_every = 500 eval/Config.print_every = 100 eval/Config.rad_mult_max = 1.0 eval/Config.rad_mult_min = 1.0 eval/Config.radial_distortion_var_loss_mult = 0.0 eval/Config.randomized = True eval/Config.rawnerf_mode = False eval/Config.render_calibration_distance = 3.0 eval/Config.render_calibration_keyframes = None eval/Config.render_camtype = None eval/Config.render_chunk_size = 32768 eval/Config.render_delete_images_when_done = True eval/Config.render_dir = \ '/home/user/camp_zipnerf_output/zipnerf/360/garden/render/' eval/Config.render_dist_adaptive = False eval/Config.render_dist_percentile = 0.5 eval/Config.render_focal = None eval/Config.render_looped_videos = False eval/Config.render_path = True eval/Config.render_path_file = None eval/Config.render_path_frames = 480 eval/Config.render_resolution = None eval/Config.render_rgb_only = False eval/Config.render_rotate_xaxis = 0.0 eval/Config.render_rotate_yaxis = 0.0 eval/Config.render_spherical = False eval/Config.render_spline_const_speed = False eval/Config.render_spline_degree = 5 eval/Config.render_spline_fixed_up = False eval/Config.render_spline_interpolate_exposure = False eval/Config.render_spline_interpolate_exposure_smoothness = 20 eval/Config.render_spline_keyframes = None eval/Config.render_spline_keyframes_choices = None eval/Config.render_spline_lock_up = False eval/Config.render_spline_lookahead_i = None eval/Config.render_spline_meters_per_sec = None eval/Config.render_spline_n_buffer = None eval/Config.render_spline_n_interp = 30 eval/Config.render_spline_outlier_keyframe_multiplier = None eval/Config.render_spline_outlier_keyframe_quantile = None eval/Config.render_spline_rot_weight = 0.1 eval/Config.render_spline_smoothness = 0.03 eval/Config.render_video_crf = 18 eval/Config.render_video_exts = ('mp4',) eval/Config.render_video_fps = 60 eval/Config.robust_loss_scale = 0.01 
eval/Config.save_calibration_to_disk = False eval/Config.scene_bbox = None eval/Config.spline_interlevel_params = {'blurs': (0.03, 0.003), 'mults': 0.01} eval/Config.train_render_every = 0 eval/Config.transform_poses_fn = None eval/Config.use_exrs = False eval/Config.use_identity_cameras = False eval/Config.use_perturbed_cameras = False eval/Config.use_tiffs = False eval/Config.vis_decimate = 0 eval/Config.vis_num_rays = 16 eval/Config.visualize_every = 10000 eval/Config.vocab_tree_path = None eval/Config.world_scale = 1.0 eval/Config.z_max = None eval/Config.z_min = None eval/Config.z_phase = 0.0 eval/Config.z_variation = 0.0

I0425 05:47:41.546334 140104660227264 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA I0425 05:47:41.548914 140104660227264 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory using 4x downsampled images Finding COLMAP data Constructing NeRF Scene Manager Warning: image_path not found for reconstruction Processing COLMAP data Loaded camera parameters for 185 images image names sorted alphabetically Loading images from /home/user/data/360_v2/garden/images_4 Loaded 185 images from disk Loaded EXIF data for 185 images Constructed COLMAP-to-world transform. Constructed 480 render poses via ellipse path Constructed train/test split: #train=161 #test=24 LLFF successfully loaded! split=DataSplit.TEST #images/poses/exposures=24 #camtoworlds=480 * resolution=(840, 1297) I0425 05:49:08.581712 140104660227264 checkpoints.py:1062] Restoring orbax checkpoint from //home/user/camp_zipnerf_output/zipnerf/360/garden/checkpoint_200000 I0425 05:49:08.584176 140104660227264 checkpointer.py:164] Restoring item from /home/user/camp_zipnerf_output/zipnerf/360/garden/checkpoint_200000. W0425 05:49:09.925947 140104660227264 transform_utils.py:229] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on. 
I0425 05:49:09.927694 140104660227264 transform_utils.py:286] The following keys are not loaded from the original tree after applying specified transforms: opt_state/0/count, opt_state/0/mu/params/MLP_0/Dense_0/bias, opt_state/0/mu/params/MLP_0/Dense_0/kernel, opt_state/0/mu/params/MLP_0/Dense_1/bias, opt_state/0/mu/params/MLP_0/Dense_1/kernel, opt_state/0/mu/params/MLP_1/Dense_0/bias, opt_state/0/mu/params/MLP_1/Dense_0/kernel, opt_state/0/mu/params/MLP_1/Dense_1/bias, opt_state/0/mu/params/MLP_1/Dense_1/kernel, opt_state/0/mu/params/MLP_2/Dense_0/bias, opt_state/0/mu/params/MLP_2/Dense_0/kernel, opt_state/0/mu/params/MLP_2/Dense_1/bias, opt_state/0/mu/params/MLP_2/Dense_1/kernel, opt_state/0/mu/params/MLP_2/Dense_2/bias, opt_state/0/mu/params/MLP_2/Dense_2/kernel, opt_state/0/mu/params/MLP_2/Dense_3/bias, opt_state/0/mu/params/MLP_2/Dense_3/kernel, opt_state/0/mu/params/MLP_2/Dense_4/bias, opt_state/0/mu/params/MLP_2/Dense_4/kernel, opt_state/0/mu/params/MLP_2/Dense_5/bias, opt_state/0/mu/params/MLP_2/Dense_5/kernel, opt_state/0/mu/params/MLP_2/Dense_6/bias, opt_state/0/mu/params/MLP_2/Dense_6/kernel, opt_state/0/mu/params/grid_0/grid_016, opt_state/0/mu/params/grid_0/grid_032, opt_state/0/mu/params/grid_0/grid_064, opt_state/0/mu/params/grid_0/grid_128, opt_state/0/mu/params/grid_0/hash_256, opt_state/0/mu/params/grid_0/hash_512, opt_state/0/mu/params/grid_1/grid_0016, opt_state/0/mu/params/grid_1/grid_0032, opt_state/0/mu/params/grid_1/grid_0064, opt_state/0/mu/params/grid_1/grid_0128, opt_state/0/mu/params/grid_1/hash_0256, opt_state/0/mu/params/grid_1/hash_0512, opt_state/0/mu/params/grid_1/hash_1024, opt_state/0/mu/params/grid_1/hash_2048, opt_state/0/mu/params/grid_2/grid_0016, opt_state/0/mu/params/grid_2/grid_0032, opt_state/0/mu/params/grid_2/grid_0064, opt_state/0/mu/params/grid_2/grid_0128, opt_state/0/mu/params/grid_2/hash_0256, opt_state/0/mu/params/grid_2/hash_0512, opt_state/0/mu/params/grid_2/hash_1024, opt_state/0/mu/params/grid_2/hash_2048, 
opt_state/0/mu/params/grid_2/hash_4096, opt_state/0/mu/params/grid_2/hash_8192, opt_state/0/nu/params/MLP_0/Dense_0/bias, opt_state/0/nu/params/MLP_0/Dense_0/kernel, opt_state/0/nu/params/MLP_0/Dense_1/bias, opt_state/0/nu/params/MLP_0/Dense_1/kernel, opt_state/0/nu/params/MLP_1/Dense_0/bias, opt_state/0/nu/params/MLP_1/Dense_0/kernel, opt_state/0/nu/params/MLP_1/Dense_1/bias, opt_state/0/nu/params/MLP_1/Dense_1/kernel, opt_state/0/nu/params/MLP_2/Dense_0/bias, opt_state/0/nu/params/MLP_2/Dense_0/kernel, opt_state/0/nu/params/MLP_2/Dense_1/bias, opt_state/0/nu/params/MLP_2/Dense_1/kernel, opt_state/0/nu/params/MLP_2/Dense_2/bias, opt_state/0/nu/params/MLP_2/Dense_2/kernel, opt_state/0/nu/params/MLP_2/Dense_3/bias, opt_state/0/nu/params/MLP_2/Dense_3/kernel, opt_state/0/nu/params/MLP_2/Dense_4/bias, opt_state/0/nu/params/MLP_2/Dense_4/kernel, opt_state/0/nu/params/MLP_2/Dense_5/bias, opt_state/0/nu/params/MLP_2/Dense_5/kernel, opt_state/0/nu/params/MLP_2/Dense_6/bias, opt_state/0/nu/params/MLP_2/Dense_6/kernel, opt_state/0/nu/params/grid_0/grid_016, opt_state/0/nu/params/grid_0/grid_032, opt_state/0/nu/params/grid_0/grid_064, opt_state/0/nu/params/grid_0/grid_128, opt_state/0/nu/params/grid_0/hash_256, opt_state/0/nu/params/grid_0/hash_512, opt_state/0/nu/params/grid_1/grid_0016, opt_state/0/nu/params/grid_1/grid_0032, opt_state/0/nu/params/grid_1/grid_0064, opt_state/0/nu/params/grid_1/grid_0128, opt_state/0/nu/params/grid_1/hash_0256, opt_state/0/nu/params/grid_1/hash_0512, opt_state/0/nu/params/grid_1/hash_1024, opt_state/0/nu/params/grid_1/hash_2048, opt_state/0/nu/params/grid_2/grid_0016, opt_state/0/nu/params/grid_2/grid_0032, opt_state/0/nu/params/grid_2/grid_0064, opt_state/0/nu/params/grid_2/grid_0128, opt_state/0/nu/params/grid_2/hash_0256, opt_state/0/nu/params/grid_2/hash_0512, opt_state/0/nu/params/grid_2/hash_1024, opt_state/0/nu/params/grid_2/hash_2048, opt_state/0/nu/params/grid_2/hash_4096, opt_state/0/nu/params/grid_2/hash_8192, opt_state/1/count 
I0425 05:49:09.967541 140104660227264 checkpointer.py:166] Finished restoring checkpoint from /home/user/camp_zipnerf_output/zipnerf/360/garden/checkpoint_200000.
I0425 05:49:09.968436 140104660227264 render.py:61] Rendering checkpoint at step 200000.
/home/user/miniconda3/envs/camp_zipnerf/lib/python3.11/site-packages/jax/_src/xla_bridge.py:945: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
I0425 05:49:10.007920 140104660227264 render.py:96] Evaluating image 1/480
I0425 05:49:10.008137 140104660227264 models.py:1046] Rendering chunk 1/34
Traceback (most recent call last):
  File "", line 198, in _run_module_as_main
  File "", line 88, in _run_code
  File "/home/user/_NeRF_Test/camp_zipnerf/render.py", line 199, in
    app.run(main)
  File "/home/user/miniconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
    ^^^^^^^^^^
  File "/home/user/_NeRF_Test/camp_zipnerf/render.py", line 194, in main
    render_config(config)
  File "/home/user/_NeRF_Test/camp_zipnerf/render.py", line 155, in render_config
    render_pipeline(config)
  File "/home/user/_NeRF_Test/camp_zipnerf/render.py", line 99, in render_pipeline
    rendering = models.render_image( # pytype: disable=wrong-arg-types # jnp-array
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/_NeRF_Test/camp_zipnerf/internal/models.py", line 1085, in render_image
    chunkrenderings, = render_fn(rng, chunk_rays)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/_NeRF_Test/camp_zipnerf/internal/train_utils.py", line 770, in render_eval_fn
    model.apply(
  File "/home/user/_NeRF_Test/camp_zipnerf/internal/models.py", line 279, in call
    ray_results = mlp(
    ^^^^
  File "/home/user/_NeRF_Test/camp_zipnerf/internal/models.py", line 779, in call
    raw_density, x = predict_density(means, covs, predict_density_kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/_NeRF_Test/camp_zipnerf/internal/models.py", line 733, in predict_density
    x = density_dense_layer(self.net_width)(x)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/camp_zipnerf/lib/python3.11/site-packages/flax/linen/linear.py", line 235, in call
    kernel = self.param(
    ^^^^^^^^^^^
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (36, 64) but got shape (12, 64) instead for parameter "kernel" in "/MLP_0/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

`

ppponpon commented 5 months ago

I also ran the same scripts in an environment with different jax and flax versions, and a different error appeared: "ValueError: The field names of the state dict and the named tuple do not match, got {'inner_state'} and {'count', 'mu', 'nu'}". Does the checkpoint written by camp/360_train.sh not match the fields that 360_eval.sh and 360_render.sh expect in their input checkpoint? Do I need to modify camp/360_train.sh or 360_render.sh?

Environment: python=3.10 jax==0.4.6 jaxlib==0.4.6 flax==0.6.1

Below is the console error log from running 360_render.sh after running camp/360_train.sh.

Traceback (most recent call last):
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/user/_NeRF_Test/_google-research/google-research/camp_zipnerf/render.py", line 199, in
    app.run(main)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/_NeRF_Test/_google-research/google-research/camp_zipnerf/render.py", line 194, in main
    render_config(config)
  File "/home/user/_NeRF_Test/_google-research/google-research/camp_zipnerf/render.py", line 155, in render_config
    render_pipeline(config)
  File "/home/user/_NeRF_Test/_google-research/google-research/camp_zipnerf/render.py", line 59, in render_pipeline
    state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/training/checkpoints.py", line 752, in restore_checkpoint
    return serialization.from_state_dict(target, state_dict)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 65, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/struct.py", line 149, in from_state_dict
    updates[name] = serialization.from_state_dict(value, value_state)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 65, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 156, in
    lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)))
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 114, in _restore_list
    y = from_state_dict(xs[i], state_dict[str(i)])
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 65, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/home/user/miniconda3/envs/camp_zipnerf_gr/lib/python3.10/site-packages/flax/serialization.py", line 146, in _restore_namedtuple
    raise ValueError('The field names of the state dict and the named tuple do not match,'
ValueError: The field names of the state dict and the named tuple do not match, got {'inner_state'} and {'count', 'mu', 'nu'}.

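The ValueError above reports that one node of the saved state has the field `inner_state` where the restore target has `count`, `mu` and `nu`. A minimal sketch of how such a mismatch could be located, assuming both trees are available as plain nested dicts (for example via `flax.serialization.to_state_dict`); the helper `diff_keys` and the toy trees are hypothetical and not part of camp_zipnerf:

```python
def diff_keys(saved, target, path=""):
    """Yield (path, saved_only, target_only) wherever two nested dicts disagree on keys."""
    if not (isinstance(saved, dict) and isinstance(target, dict)):
        return  # leaves are not compared here, only dict key sets
    saved_keys, target_keys = set(saved), set(target)
    if saved_keys != target_keys:
        yield (path or "/",
               sorted(saved_keys - target_keys),
               sorted(target_keys - saved_keys))
    # Recurse only into keys that both sides share.
    for key in sorted(saved_keys & target_keys):
        yield from diff_keys(saved[key], target[key], f"{path}/{key}")


# Toy stand-ins (assumptions): they only mimic the field names in the error message.
saved = {"params": {}, "opt_state": {"0": {"inner_state": {}}}}
target = {"params": {}, "opt_state": {"0": {"count": 0, "mu": {}, "nu": {}}}}

key_mismatches = list(diff_keys(saved, target))
for path, saved_only, target_only in key_mismatches:
    print(f"{path}: only in saved {saved_only}, only in target {target_only}")
```

A report like this only shows where the two structures diverge; it does not say which side is the intended one.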
frasiolas commented 4 months ago

I had the same problem: flax.errors.ScopeParamShapeError: Initializer expected to generate shape (36, 64) but got shape (12, 64) instead for parameter "kernel" in "/MLP_0/Dense_0".

The solution was to load the correct config file, i.e. the one for the model that I trained.
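To check whether a config mismatch is the cause before a full render run, one option is to compare parameter shapes from two sources, such as the checkpoint and a model built from the config passed to the render script. This is a minimal sketch under the assumption that both have been reduced to nested dicts of shape tuples; `diff_shapes` and the toy trees are hypothetical, and which of the two shapes belongs to which side is chosen arbitrarily here:

```python
def diff_shapes(left, right, path=""):
    """Yield (path, left_shape, right_shape) for every leaf whose shape differs."""
    for key in sorted(set(left) | set(right)):
        here = f"{path}/{key}"
        a, b = left.get(key), right.get(key)
        if isinstance(a, dict) and isinstance(b, dict):
            yield from diff_shapes(a, b, here)
        elif a != b:
            # Also reports leaves that exist on one side only (other side is None).
            yield here, a, b


# Toy trees (assumptions) using the two shapes quoted in the error message.
left = {"MLP_0": {"Dense_0": {"kernel": (36, 64), "bias": (64,)}}}
right = {"MLP_0": {"Dense_0": {"kernel": (12, 64), "bias": (64,)}}}

shape_mismatches = list(diff_shapes(left, right))
for path, a, b in shape_mismatches:
    print(f"{path}: {a} vs {b}")
```

If the first mismatch appears at the input layer of an MLP, as it does here, the two sides were probably built with different feature settings, which is consistent with the fix of using the training config.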