robodhruv / visualnav-transformer

Official code and checkpoint release for mobile robot foundation models: GNM, ViNT, and NoMaD.
http://general-navigation-models.github.io
MIT License

Batch Size Mismatch in explore.py #20

Closed · chungcode0218 closed this issue 2 months ago

chungcode0218 commented 3 months ago

Hello,

I've been working with your project and examining the handling of model inputs in explore.py and navigate.py, especially in the context of simulations using CARLA, where I receive RGBA data. I preprocess this data to convert it from RGBA to RGB to align with the expected input format of your scripts.
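For context, my conversion step looks roughly like the following. The helper name and the exact channel layout are specific to my CARLA setup, so please treat this as a sketch rather than code from this repo:

import numpy as np
from PIL import Image as PILImage

def camera_frame_to_pil_rgb(raw_image) -> PILImage.Image:
    # Reshape the flat byte buffer from the CARLA camera into H x W x 4.
    array = np.frombuffer(raw_image.raw_data, dtype=np.uint8).reshape(
        raw_image.height, raw_image.width, 4)
    # Drop the alpha channel; depending on the sensor configuration the
    # remaining channels may also need reordering (e.g. BGR -> RGB).
    rgb = np.ascontiguousarray(array[:, :, :3])
    return PILImage.fromarray(rgb)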

In navigate.py, I noticed a specific approach to adjusting the batch size for the observed images (obs_img) and goal images (goal_img) before feeding them into the model:

obsgoal_cond = model('vision_encoder', obs_img=obs_images.repeat(len(goal_image), 1, 1, 1), goal_img=goal_image, input_goal_mask=mask.repeat(len(goal_image)))

Here, obs_images is tiled along the batch dimension with .repeat(len(goal_image), 1, 1, 1) and mask with .repeat(len(goal_image)), so both match the batch size of goal_image and the forward pass stays dimensionally consistent.
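To make the shape bookkeeping concrete (the sizes below are made up, purely to illustrate how .repeat aligns the batch dimension):

import torch

goal_image = torch.zeros(8, 3, 96, 96)   # e.g. 8 goal candidates from the topomap
obs_images = torch.zeros(1, 12, 96, 96)  # a single stacked observation
mask = torch.ones(1, dtype=torch.long)

obs_tiled = obs_images.repeat(len(goal_image), 1, 1, 1)  # -> (8, 12, 96, 96)
mask_tiled = mask.repeat(len(goal_image))                # -> (8,)
# Now obs_tiled, goal_image, and mask_tiled all share batch size 8.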

However, in explore.py, a similar strategy doesn't seem to be applied when dealing with comparable inputs:

obs_cond = model('vision_encoder', obs_img=obs_images, goal_img=fake_goal, input_goal_mask=mask)

In my experiments, fake_goal defaults to a batch size of 1, while obs_images can have a different batch size, so combining them as model input raises a batch size mismatch error.
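For reference, the adjustment I would have expected, mirroring navigate.py, is roughly the following (this is only a sketch of the workaround I am considering, not code from explore.py):

batch = obs_images.shape[0]
# Tile the placeholder goal and the goal mask along the batch dimension so
# they match obs_images, analogous to how navigate.py tiles obs_images to
# match the goal candidates.
obs_cond = model(
    'vision_encoder',
    obs_img=obs_images,
    goal_img=fake_goal.repeat(batch, 1, 1, 1),
    input_goal_mask=mask.repeat(batch),
)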

Could you shed some light on the rationale behind the handling in explore.py? Is there a specific reason for not adjusting the batch size of fake_goal to match obs_images as done in navigate.py? And would you recommend a particular approach to address the batch size mismatch issue I encountered in explore.py within the CARLA simulation context?

Thank you very much for your time and assistance. I look forward to your response.

[screenshot of the batch size mismatch error]

yzl123 commented 3 months ago
topomap = []
for i in range(num_nodes):
    image_path = os.path.join(topomap_dir, topomap_filenames[i])
    topomap.append(PILImage.open(image_path))

def msg_to_pil(msg: Image) -> PILImage.Image:
    img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
        msg.height, msg.width, -1)
    pil_image = PILImage.fromarray(img)
    return pil_image

update:

topomap = []
for i in range(num_nodes):
    image_path = os.path.join(topomap_dir, topomap_filenames[i])
    # Force 3-channel RGB so the topomap images match the model's expected input.
    topomap.append(PILImage.open(image_path).convert('RGB'))

def msg_to_pil(msg: Image) -> PILImage.Image:
    img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
        msg.height, msg.width, -1)
    # Convert to RGB here as well, so incoming frames with an alpha channel
    # are reduced to 3 channels before being passed to the model.
    pil_image = PILImage.fromarray(img).convert('RGB')
    return pil_image