aimagelab / show-control-and-tell

Show, Control and Tell: A Framework for Generating Controllable and Grounded Captions. CVPR 2019
https://arxiv.org/abs/1811.10652
BSD 3-Clause "New" or "Revised" License

Training of Sinkhorn Operator and Data Definition #21

Open lindatan90 opened 4 years ago

lindatan90 commented 4 years ago

Hi, may I know whether the code for training the Sinkhorn network is available? Currently, I found that only test_region_set.py uses the pretrained Sinkhorn network, but I'm interested in knowing how it was trained from scratch.

Also, I'm a bit confused about all the data loaded into the project. I'll write down my understanding of the data below; please correct me if I'm wrong.

NguyenVanThanhHust commented 3 years ago

Hi, I'm confused too.

Hi @lindatan90, from what I read at train.py line 104, the captions tensor should have shape (bs, 20). I printed it to check.

For detections: detections.shape is torch.Size([bs, 100, 2048]). The first dim is the batch size, the second is the number of feature vectors, and the last is the dimension of each feature vector, so each image has 100 feature vectors of size 2048.
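A minimal sketch of how a batch with that shape could be assembled, assuming each image has a variable number of region features that are zero-padded up to 100. The helper pad_detections is hypothetical and not taken from the repo.

```python
import torch

def pad_detections(per_image_feats, max_dets=100, feat_dim=2048):
    """Stack per-image region features into a (bs, max_dets, feat_dim) tensor.

    Hypothetical helper: assumes unused detection slots are filled with zeros.
    """
    batch = torch.zeros(len(per_image_feats), max_dets, feat_dim)
    for i, feats in enumerate(per_image_feats):
        n = min(feats.shape[0], max_dets)  # truncate if an image has more than max_dets regions
        batch[i, :n] = feats[:n]
    return batch

# Two images with 37 and 100 detected regions, respectively.
detections = pad_detections([torch.randn(37, 2048), torch.randn(100, 2048)])
print(detections.shape)  # torch.Size([2, 100, 2048])
```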

For captions: the shape should be [bs, 20]. Each caption is encoded with the vocabulary dictionary and looks something like: tensor([[ 2, 4, 334, 53, 577, 98, 483, 10, 274, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
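The printed example suggests that index 2 marks the start of the caption, 3 marks the end, and 1 is padding. Under that assumption (inferred only from this one tensor, not from the repo's vocabulary), a caption could be encoded to a fixed length of 20 as below; encode_caption and the PAD/BOS/EOS constants are hypothetical.

```python
import torch

# Assumed special indices, inferred from the printed example only.
PAD, BOS, EOS = 1, 2, 3

def encode_caption(word_ids, max_len=20):
    """Wrap word indices with BOS/EOS and right-pad with PAD up to max_len."""
    ids = [BOS] + list(word_ids)[: max_len - 2] + [EOS]  # leave room for BOS and EOS
    ids += [PAD] * (max_len - len(ids))
    return torch.tensor(ids)

caption = encode_caption([4, 334, 53, 577, 98, 483, 10, 274]).unsqueeze(0)  # add batch dim
print(caption.shape)  # torch.Size([1, 20])
```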

For ctrl_det_seqs: the shape should be [bs, 20, 20, 2048]. The 1st dim is the batch size, the 2nd dim is the time step, and the last 2 dims are just a list of feature vectors. This confused me a bit at first, but you can think of it this way: they detect 20 objects and take the feature vector of each, giving [20x2048]; this is duplicated 20 times, once per time step, giving [20x20x2048]; then the batch dimension is added.
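Following this reading of the tensor (the commenter's interpretation, not confirmed against the repo's data loader), the duplication across time steps can be reproduced with expand. The sizes below are illustrative.

```python
import torch

bs, n_objs, feat_dim, seq_len = 2, 20, 2048, 20

# Features of the 20 selected objects for each image: (bs, 20, 2048).
obj_feats = torch.randn(bs, n_objs, feat_dim)

# Repeat the same object set at every time step: (bs, 20, 20, 2048).
# expand() returns a view without copying; call .contiguous() if a real copy is needed.
ctrl_det_seqs = obj_feats.unsqueeze(1).expand(bs, seq_len, n_objs, feat_dim)
print(ctrl_det_seqs.shape)  # torch.Size([2, 20, 20, 2048])
```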

NguyenVanThanhHust commented 3 years ago

You might want to read model_file

The image descriptor is extracted by computing a mask over the detection tensor and averaging the masked features; the inputs are then created by concatenating [embedded word, image descriptor].
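A sketch of that step, assuming padded detection rows are all zeros so the mask can be derived from the features themselves. build_inputs and the 300-dimensional word embedding are hypothetical, and the repo may build its mask differently.

```python
import torch

def build_inputs(word_emb, detections):
    """word_emb: (bs, emb_dim); detections: (bs, n_dets, feat_dim).

    Assumes padded (unused) detection rows are all zeros.
    """
    mask = (detections.abs().sum(dim=-1) != 0).float()        # (bs, n_dets), 1 for real rows
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1)       # avoid division by zero
    image_descriptor = (detections * mask.unsqueeze(-1)).sum(dim=1) / counts  # masked mean, (bs, feat_dim)
    return torch.cat([word_emb, image_descriptor], dim=-1)    # (bs, emb_dim + feat_dim)

word_emb = torch.randn(2, 300)
dets = torch.zeros(2, 100, 2048)
dets[:, :10] = torch.randn(2, 10, 2048)  # only the first 10 detections are real
inputs = build_inputs(word_emb, dets)
print(inputs.shape)  # torch.Size([2, 2348])
```

Averaging only over unmasked rows keeps the zero padding from pulling the descriptor toward zero.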

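seqs[1] = ctrl_dets_seqs is the list of detected features of the 20 objects, and the current detections are read as det_curr = seqs[1][:, t] (commented "# state 2" in the code). If you add the line print(torch.equal(seqs[1][:, t], seqs[1][:, t+1])), it will output True, True, ... and then raise an index-out-of-bounds error, because t+1 runs past the last time step. If you get False, you can print both tensors and verify that the difference is just numerical error.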
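To run this check without hitting the out-of-bounds error, and to tolerate small numerical differences, you can stop the loop one step early and compare with torch.allclose. steps_identical is a hypothetical helper, and the tolerance is an arbitrary choice.

```python
import torch

def steps_identical(ctrl_det_seqs, atol=1e-6):
    """Return True if every time step along dim 1 holds the same features (up to atol)."""
    T = ctrl_det_seqs.shape[1]
    # Iterate only to T - 1 so that index t + 1 never goes out of bounds.
    return all(
        torch.allclose(ctrl_det_seqs[:, t], ctrl_det_seqs[:, t + 1], atol=atol)
        for t in range(T - 1)
    )

seqs = torch.randn(2, 1, 20, 2048).repeat(1, 20, 1, 1)  # same features at all 20 steps
print(steps_identical(seqs))  # True

changed = seqs.clone()
changed[:, 10] = torch.randn(2, 20, 2048)  # make one step different
print(steps_identical(changed))  # False
```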

NguyenVanThanhHust commented 3 years ago

For Sinkhorn operator training, you could read their paper and write your own code.
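As a starting point, here is a sketch of the standard Sinkhorn operator: iterative row and column normalization in log space, as used in Gumbel-Sinkhorn networks. It is not the authors' training code; how the score matrix is produced from region features, the temperature, and the loss are all left open here.

```python
import torch

def sinkhorn(log_alpha, n_iters=20, tau=1.0):
    """Sinkhorn normalization in log space.

    log_alpha: (bs, n, n) unnormalized scores. Returns an approximately
    doubly-stochastic matrix (rows and columns each sum to about 1).
    Lower tau pushes the result closer to a hard permutation matrix.
    """
    log_alpha = log_alpha / tau
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # normalize rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # normalize columns
    return log_alpha.exp()

torch.manual_seed(0)
scores = torch.randn(2, 5, 5)  # e.g. scores for ordering 5 region sets
P = sinkhorn(scores, n_iters=100)
print(P.sum(dim=1))  # column sums, all close to 1
print(P.sum(dim=2))  # row sums, close to 1
```

Because every step is differentiable (logsumexp and exp), the operator can sit inside a network and be trained end to end with gradient descent.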