allenai / Holodeck

CVPR 2024: Language Guided Generation of 3D Embodied AI Environments.
https://yueyang1996.github.io/holodeck
Apache License 2.0
304 stars 25 forks source link

How to generate CLIP and SBERT features for an object? #15

Closed Dancing-Github closed 6 months ago

Dancing-Github commented 6 months ago

Very impressive work! I'm trying to add some custom objects to the inference pipeline, but stuck in the step of calculating similarity between query and object. I want to know more about the calculation of object features.

Which specific models are used for CLIP and SBERT, respectively? What is the input to CLIP: a single image of the object, or images from multiple views? What is the input to SBERT? It should contain the annotations from GPT, but what else is needed, and what is the input text template?

Could you explain the questions above? Or further release the corresponding code? Looking forward to your reply.

YueYANG1996 commented 6 months ago

The CLIP and SBERT are here: https://github.com/allenai/Holodeck/blob/7094a014acc6ad0c8ddb9a739e3760940e1f3e21/modules/holodeck.py#L33 https://github.com/allenai/Holodeck/blob/7094a014acc6ad0c8ddb9a739e3760940e1f3e21/modules/holodeck.py#L37

For the CLIP feature, I took three views of the object: 0 degrees, 45 degrees, and -45 degrees. For SBERT, we use the text description of the asset as input. Please see Figures 16 and 19 of the paper for details.

YueYANG1996 commented 4 months ago

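To get the screenshots of an asset, you could use the following code. Run it as `python <script>.py <asset_folder> <version>`. For each asset it saves nine screenshots, named `0.png`, `45.png`, …, `315.png`, and `top.png`: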

import os
import sys
import math
import json
import copy
from PIL import Image
from tqdm import tqdm
import ai2thor.controller
import ai2thor.fifo_server
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetHookRunner

instance_id = "asset_0"  # id given to the single object placed in each rendered house
width = 512  # frame width passed to the AI2-THOR controller
height = 512  # frame height passed to the AI2-THOR controller

# Template house; each render deep-copies it and fills in "objects".
# NOTE(review): this path is resolved relative to the current working
# directory, not to this script's location.
# A `with` block closes the file handle; JSON is read as UTF-8 explicitly
# rather than with the platform's locale encoding.
with open("../modules/empty_house.json", "r", encoding="utf-8") as house_file:
    empty_house = json.load(house_file)

def make_single_object_house(asset_id, asset_dimension, rotation, house):
    """Replace ``house["objects"]`` with one kinematic object and return ``house``.

    The object is placed at x = z = 0. Its y position is half of
    ``asset_dimension["y"]``, presumably so that its base rests at y = 0.
    ``house`` is modified in place and also returned for convenience.

    Args:
        asset_id: Asset identifier to place.
        asset_dimension: Mapping whose ``"y"`` entry is the object's height.
        rotation: Indexable ``(x, y, z)`` rotation values.
        house: House dict to populate (callers pass a fresh copy).
    """
    placed = {
        "assetId": asset_id,
        "id": instance_id,
        "kinematic": True,
        "position": {"x": 0, "y": asset_dimension["y"] / 2, "z": 0},
        "rotation": {axis: rotation[idx] for idx, axis in enumerate("xyz")},
        "layer": "Procedural0",
        "material": None,
    }
    house["objects"] = [placed]
    return house

def init_controller():
    """Start and return a new AI2-THOR controller on the "Procedural" scene.

    Uses the module-level ``assetFolder`` (assigned in the ``__main__``
    section) as ``asset_directory`` for the procedural-asset hook. Uses the
    module-level ``width``/``height`` for the frame size.
    """
    asset_hook = ProceduralAssetHookRunner(
        asset_directory=assetFolder,
        asset_symlink=True,
        verbose=True,
    )
    return ai2thor.controller.Controller(
        start_unity=True,
        scene="Procedural",
        makeAgentsVisible=False,
        gridSize=0.25,
        width=width,
        height=height,
        server_class=ai2thor.fifo_server.FifoServer,
        action_hook_runner=asset_hook,
    )

def view_asset_in_thor(asset_id, asset_dimension, controller, output_dir):
    """Render one asset from nine orientations and save each frame as a PNG.

    Steps:
    1. Build a house that contains only ``asset_id``, starting from a deep
       copy of the module-level ``empty_house``.
    2. Reset the controller to the "Procedural" scene and create that house.
    3. Point the camera at the object and set a white skybox.
    4. Rotate the object through nine absolute orientations, saving each
       frame as ``<label>.png`` in ``output_dir``. The labels are
       ``0``, ``45``, …, ``315`` and ``top``.

    Args:
        asset_id: Asset identifier to render.
        asset_dimension: Bounding-box mapping; its ``"y"`` entry is used to
            position the object (see ``make_single_object_house``).
        controller: A running AI2-THOR controller.
        output_dir: Directory for the PNGs; created (with parents) if missing.

    Returns:
        The event returned by the last ``RotateObject`` step.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # (file label, absolute (x, y, z) rotation) pairs. Keeping each label next
    # to its rotation avoids indexing two parallel lists. Most views also
    # tilt about x and z, not only about y.
    views = [
        (0, (30, 0, 0)),
        (45, (30, 45, 30)),
        (90, (0, 90, 30)),
        (135, (-30, 135, 30)),
        (180, (-30, 180, 0)),
        (225, (-30, 225, -30)),
        (270, (0, 270, -30)),
        (315, (30, 315, -30)),
        ("top", (90, 0, 0)),
    ]

    # The house is created with the first view's rotation already applied.
    house = make_single_object_house(
        asset_id, asset_dimension, views[0][1], copy.deepcopy(empty_house)
    )

    controller.reset(scene="Procedural")
    evt = controller.step(action="CreateHouse", house=house)
    evt = controller.step(action="LookAtObjectCenter", objectId=instance_id)
    evt = controller.step(action="SetSkybox", color={"r": 1, "g": 1, "b": 1})

    for label, (rot_x, rot_y, rot_z) in views:
        evt = controller.step(
            action="RotateObject",
            angleAxisRotation={"x": rot_x, "y": rot_y, "z": rot_z},
            absolute=True,
        )
        Image.fromarray(evt.frame).save(os.path.join(output_dir, f"{label}.png"))

    return evt

def main():
    """Render screenshots for every asset in the objaverse_plus database.

    Reads the module-level ``version`` (assigned in the ``__main__`` section)
    to locate ``objaverse_plus_valid_database.json``. Screenshots are written
    under ``../data/objaverse_plus/<version>/asset_screenshot/<asset_id>/``.

    The controller is restarted periodically and after any failed render.
    A failed asset is reported and skipped. The last controller is always
    stopped, including on Ctrl-C.
    """
    output_dir = f"../data/objaverse_plus/{version}/asset_screenshot"
    database_path = f"../data/objaverse_plus/{version}/objaverse_plus_valid_database.json"
    with open(database_path, "r", encoding="utf-8") as database_file:
        database = json.load(database_file)

    # The counter hits this value on the 500th asset since the last restart.
    # That asset is rendered on a fresh controller.
    restart_every = 500

    controller = init_controller()
    count = 0
    try:
        for asset_id in tqdm(database):
            count += 1
            if count == restart_every:
                controller.stop()
                controller = None  # don't stop it a second time in `finally`
                controller = init_controller()
                count = 0
            try:
                asset_dimension = database[asset_id]["assetMetadata"]["boundingBox"]
                view_asset_in_thor(asset_id, asset_dimension, controller, os.path.join(output_dir, asset_id))
            except Exception as err:
                # Best-effort batch job: report the asset, restart the
                # (possibly wedged) controller, and move on. Catching
                # Exception rather than using a bare `except:` lets
                # KeyboardInterrupt/SystemExit propagate.
                print(f"Failed to render {asset_id}: {err!r}")
                controller.stop()
                controller = None
                controller = init_controller()
                count = 0
    finally:
        # Stop the last controller so its Unity process does not outlive the script.
        if controller is not None:
            controller.stop()

if __name__ == "__main__":
    # Usage: python <this_script>.py ASSET_FOLDER VERSION
    # The arguments become the module globals that init_controller()
    # (assetFolder) and main() (version) read. Plain assignments here already
    # create module globals; a `global` statement at module level is a no-op.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} ASSET_FOLDER VERSION")
    assetFolder = sys.argv[1]
    version = sys.argv[2]
    main()