sarafridov / K-Planes

question about the effect of the T dimension #19

Closed wangyuyy closed 1 year ago

wangyuyy commented 1 year ago

I tried to verify the effect of the T dimension on the result. To do so, I fixed the values along the T dimension of the grids' parameters (pretrained parameters for both the K-Planes grids and the proposal_network grids) after t0, to find out whether the scene can be stopped at t0. Ideally, the scene should maintain its state at t0. But the results were not ideal, as shown in the video attachments: in one the scene is stopped at the beginning, in the other at the halfway point. I wonder whether there are other time-relevant parts I missed. The raw output videos have some issues playing on the website, so please download them to play.

https://drive.google.com/drive/folders/1rjbIvXhAksc1KftudVy0HHSAZr4bhyIx?usp=share_link

sarafridov commented 1 year ago

Let me first verify what you're doing to make sure I understand (let me know if this is incorrect): You are using a pretrained model, either from our Google Drive or from your own model training. You take the saved space-time grids for both the main model and the proposal model, and you replace all values after t0 with the values at t0, along the time dimension of each grid. Is that right? If so then I would expect it should work...

Another method would be to clip the time values associated with each ray to be in [0, t0], and either leave the model as-is or remove all the values after t0 in the time dimension of each grid.
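A minimal sketch of the clipping idea, assuming per-ray timestamps are plain floats and that the valid range starts at `t_min` (the helper name `clip_times` and the sample values are hypothetical, and the actual time normalization used by the data loader should be checked before choosing `t_min` and `t0`):

```python
def clip_times(times, t0, t_min=0.0):
    """Clamp each per-ray timestamp into [t_min, t0]."""
    if t0 < t_min:
        raise ValueError("t0 must not be smaller than t_min")
    return [min(max(t, t_min), t0) for t in times]


# Hypothetical timestamps for five rays, normalized to [0, 1].
ray_times = [0.0, 0.25, 0.5, 0.75, 1.0]
frozen = clip_times(ray_times, t0=0.5)
print(frozen)  # every timestamp after t0 is replaced by t0
```

On a tensor of timestamps the same operation can be written as `torch.clamp(times, min=t_min, max=t0)`.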

wangyuyy commented 1 year ago

Sorry for the unclear explanation. I did exactly what you described. I modified the pretrained model parameters using the function below before calling trainer.load_model(model), and then used --spacetime-only to generate the output. But the results didn't meet expectations, as shown in the video files above. The other two methods you proposed are clearly effective; however, I still want to find out why this one didn't work. Thanks a lot for your help!

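def operate_model(model, target, operation: str):
    # `model` is the loaded checkpoint; its 'model' entry holds the parameters.
    model_param = model['model']
    if operation == 'stop_half_time':
        for reso in range(4):  # grid levels of the main field
            for plane in [2, 4, 5]:  # planes that include the time axis
                # Size of the time axis, read from the second-to-last dimension.
                resolutions = model_param[f'field.grids.{reso}.{plane}'].shape[-2]
                slice = int(resolutions / 2)
                # Overwrite the second half of the time axis with the last step of the first half.
                for i in range(slice):
                    model_param[f'field.grids.{reso}.{plane}'][:, :, -slice + i, :] = model_param[f'field.grids.{reso}.{plane}'][:, :, -slice - 1, :]
                if reso < 2:  # proposal networks 0 and 1
                    # Note: this reuses `slice` computed from the main grid above.
                    for i in range(slice):
                        model_param[f'proposal_networks.{reso}.grids.{plane}'][:, :, -slice + i, :] = model_param[f'proposal_networks.{reso}.grids.{plane}'][:, :, -slice - 1, :]
    return model

sarafridov commented 1 year ago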

Is the time resolution the same for the main model and the proposal model? If not (e.g., if the proposal model is lower-resolution in time), then this code could cause the proposal model to use effectively a later cutoff compared to the main model. It'd be safer to set the slice value separately for the main model and the proposal model.
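To make the per-grid cutoff concrete, here is a small sketch that derives the indices from each grid's own time resolution. The helper name `freeze_plan` and the resolutions 25 and 50 are hypothetical, chosen only to show how a shared slice value shifts the cutoff:

```python
def freeze_plan(time_res, fraction=0.5):
    """Return (source_index, first_overwritten_index) for one grid.

    Time steps from first_overwritten_index onward are meant to be
    replaced by the values stored at source_index.
    """
    n_frozen = int(time_res * (1 - fraction))
    first = time_res - n_frozen
    if first < 1:
        raise ValueError("fraction leaves no time step to copy from")
    return first - 1, first


main_res, prop_res = 25, 50  # hypothetical time resolutions

main_src, main_first = freeze_plan(main_res)
prop_src, prop_first = freeze_plan(prop_res)

# Reusing the main grid's slice length on the proposal grid instead:
shared_len = main_res - main_first
shared_first = prop_res - shared_len

print(main_first / main_res)    # relative cutoff of the main grid
print(prop_first / prop_res)    # relative cutoff with a per-grid slice
print(shared_first / prop_res)  # relative cutoff with the shared slice
```

With these assumed numbers, the per-grid computation freezes both grids at roughly the halfway point, while the shared slice length freezes the proposal grid at a different fraction of the sequence.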

Another thing to check is that you are doing the slicing in the right dimension (the time dimension). From the initialization code (https://github.com/sarafridov/K-Planes/blob/main/plenoxels/models/kplane_field.py#L39), if I remember correctly, time would be the last dimension (rather than second-to-last), but you can double-check by printing the shape of each grid and comparing it to the resolution in your config.
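A minimal sketch of that shape check, assuming the checkpoint's parameters are a mapping from names to tensors and that grid entries have `grids` in their key (as in the `field.grids.{reso}.{plane}` keys used in the function above). The stand-in entries and their shapes are hypothetical; only `.shape` is read, so real tensors work the same way:

```python
from types import SimpleNamespace


def grid_shapes(params):
    """Map each grid parameter name to its shape as a plain tuple."""
    return {name: tuple(p.shape) for name, p in params.items() if "grids" in name}


# Hypothetical stand-ins for tensors; names and shapes are illustrative only.
fake_params = {
    "field.grids.0.2": SimpleNamespace(shape=(1, 32, 25, 64)),
    "proposal_networks.0.grids.2": SimpleNamespace(shape=(1, 8, 50, 128)),
    "field.sigma_net.weight": SimpleNamespace(shape=(64, 32)),
}

shapes = grid_shapes(fake_params)
for name, shape in shapes.items():
    print(name, shape)
```

With a real checkpoint, the mapping would come from something like `torch.load('<file_path>')['model']`, and the axis whose size matches the time resolution in the config is the time dimension.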

wangyuyy commented 1 year ago

Thanks for your advice! I didn't notice that the size of the time dimension of the proposal_networks grids differs from that of the K-Planes grids. Many thanks again for your generous help. By the way, according to the initialization code, the time dimension should be the second-to-last dimension (due to coo_comb[::-1]). Below is the experiment code, based on the config file and the code in kplane_field.py. The result is consistent with the model file, as shown below, where lego=torch.load('<file_path>').

[Two image attachments, not reproduced here]
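Since the screenshots are not reproduced, here is a rough sketch of such a check. It assumes each plane's shape is built as `[1, out_dim]` followed by the resolutions indexed by the reversed coordinate pair (the `coo_comb[::-1]` expression mentioned above). The function name `plane_shapes`, the resolution `[64, 64, 64, 25]`, and `out_dim=32` are hypothetical:

```python
import itertools


def plane_shapes(reso, out_dim, grid_nd=2):
    """List (coordinate pair, shape) for every plane, in combination order."""
    planes = []
    for coo_comb in itertools.combinations(range(len(reso)), grid_nd):
        # Reversed pair: the later coordinate comes first in the shape.
        shape = [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
        planes.append((coo_comb, shape))
    return planes


reso = [64, 64, 64, 25]  # hypothetical x, y, z, t resolutions
planes = plane_shapes(reso, out_dim=32)

# Planes whose coordinate pair includes the time axis (index 3).
time_planes = [i for i, (comb, _) in enumerate(planes) if 3 in comb]
print(time_planes)
for i in time_planes:
    print(i, planes[i])
```

Under these assumptions the time planes are indices 2, 4 and 5, and in each of them the time resolution sits in the second-to-last position of the shape, which is what the slicing with `[:, :, ..., :]` relies on.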