kyegomez / MultiModalMamba

A novel implementation of fusing ViT with Mamba into a fast, agile, and high-performance Multi-Modal Model. Powered by Zeta, the simplest AI framework ever.
https://discord.gg/GYbXvDGevY
MIT License
429 stars, 23 forks

[TEST][BUG][DOCS] Readme demo - ViT resizing #3

Closed: evelynmitchell closed this issue 8 months ago

evelynmitchell commented 8 months ago

The second demo in the README is the "Ready to Train Model":

import torch  # Import the torch library

# Import the MMM model from the mm_mamba module
from mm_mamba.model import MMM

# Generate a random tensor 'x' of shape (1, 224) with integer token IDs in [0, 10000)
x = torch.randint(0, 10000, (1, 224))

# Generate a random image tensor 'img' of size (1, 3, 224, 224)
img = torch.randn(1, 3, 224, 224)

# Create an MMM model with the following parameters:
model = MMM(
    vocab_size=10000,
    dim=512,
    depth=6,
    dropout=0.1,
    heads=8,
    d_state=512,
    image_size=224,
    patch_size=16,
    encoder_dim=512,
    encoder_depth=6,
    encoder_heads=8,
    fusion_method="mlp",
    return_embeddings=False,
)

# Pass the tensor 'x' and 'img' through the model and store the output in 'out'
out = model(x, img)

# Print the shape of the output tensor 'out'
print(out.shape)

When I pip install the package from GitHub in Colab and run this, it fails with a dimension mismatch. I believe this is related to how the ViT splits the image into patches: the 224-pixel image side is divided by the patch size (224 // 16), and that count is squared to get the total number of patches. The math is:

224//16 = 14
14 * 14 = 196

The error is that the 196 image-patch tokens do not match the 224 text tokens.
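The mismatch can be reproduced with plain PyTorch. This is a minimal sketch, assuming (as the error suggests) that the `"mlp"` fusion combines image-patch embeddings and text embeddings by element-wise addition. The `Conv2d` patch embedding below is a generic ViT-style stand-in, not the actual encoder inside `MMM`.

```python
import torch
import torch.nn as nn

vocab_size, dim = 10000, 512
image_size, patch_size = 224, 16

# Text side: 224 token IDs -> (1, 224, 512) embeddings
text = torch.randint(0, vocab_size, (1, 224))
text_emb = nn.Embedding(vocab_size, dim)(text)

# Image side: generic ViT-style patch embedding (assumed stand-in for the MMM encoder).
# A 16x16 conv with stride 16 turns a 224x224 image into a 14x14 grid of patches.
img = torch.randn(1, 3, image_size, image_size)
patches = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)(img)  # (1, 512, 14, 14)
img_emb = patches.flatten(2).transpose(1, 2)  # (1, 196, 512)

print(text_emb.shape)  # torch.Size([1, 224, 512])
print(img_emb.shape)   # torch.Size([1, 196, 512])

# Assumed additive fusion needs equal sequence lengths, so 224 vs 196 fails.
mismatch = None
try:
    fused = text_emb + img_emb
except RuntimeError as err:
    mismatch = err
    print("shape mismatch:", err)
```

Only the sequence dimension (224 vs 196) differs; the embedding dimension (512) already agrees because `dim` and `encoder_dim` are both 512.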

I believe this would be fixed by using a text input length of 196.
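To keep the demo input in sync with the constructor arguments, the text length can be derived from `image_size` and `patch_size` instead of hard-coded. `patch_seq_len` below is a hypothetical helper, not part of `mm_mamba`; it assumes square images whose side is an exact multiple of the patch size, as with 224 and 16.

```python
import torch

def patch_seq_len(image_size: int, patch_size: int) -> int:
    """Number of ViT patch tokens for a square image (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size={image_size} is not divisible by patch_size={patch_size}"
        )
    per_side = image_size // patch_size  # 224 // 16 = 14 patches per side
    return per_side * per_side           # 14 * 14 = 196 patches in total

seq_len = patch_seq_len(224, 16)
x = torch.randint(0, 10000, (1, seq_len))  # token IDs in [0, 10000)
print(seq_len, x.shape)  # 196 torch.Size([1, 196])
```

The divisibility check is this helper's own choice: it fails loudly on sizes like 225 rather than silently flooring the patch count.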

import torch  # Import the torch library

# Import the MMM model from the mm_mamba module
from mm_mamba.model import MMM

# Generate a random tensor 'x' of shape (1, 196) with integer token IDs in [0, 10000)
# 196 = (image_size // patch_size) ** 2 = (224 // 16) ** 2
x = torch.randint(0, 10000, (1, 196))

# Generate a random image tensor 'img' of size (1, 3, 224, 224)
img = torch.randn(1, 3, 224, 224)

# Create an MMM model with the following parameters:
model = MMM(
    vocab_size=10000,
    dim=512,
    depth=6,
    dropout=0.1,
    heads=8,
    d_state=512,
    image_size=224,
    patch_size=16,
    encoder_dim=512,
    encoder_depth=6,
    encoder_heads=8,
    fusion_method="mlp",
    return_embeddings=False,
)

# Pass the tensor 'x' and 'img' through the model and store the output in 'out'
out = model(x, img)

# Print the shape of the output tensor 'out'
print(out.shape)

README fix incoming.

github-actions[bot] commented 8 months ago

Hello there, thank you for opening an issue! 🙏🏻 The team was notified and will get back to you ASAP.

kyegomez commented 8 months ago

@evelynmitchell Good catch. I don't know what I did to introduce this error; I tested it rigorously before pushing.

evelynmitchell commented 8 months ago

You're welcome, Kye.