I'm trying to use a pre-trained XCiT trunk to train 3 parallel heads for a multi-output image classification problem. Basically, I have images that need to be classified across three categories (so each image will receive three labels, one from each head).
I have set up my configuration (pasted at the bottom) to contain three MLP heads with three, four, and three label classes, respectively. When I try to run a forward training pass, I get this error:
Traceback (most recent call last):
  File "/home/mbarna/Projects/mldevautomation/run_pipeline.py", line 35, in <module>
    main()
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/main.py", line 32, in decorated_main
    _run_hydra(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 346, in _run_hydra
    run_and_report(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 201, in run_and_report
    raise ex
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 198, in run_and_report
    return func()
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 347, in <lambda>
    lambda: hydra.run(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 107, in run
    return run_job(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/core/utils.py", line 129, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/mbarna/Projects/mldevautomation/run_pipeline.py", line 30, in main
    summary = pipeline.execute()
  File "/home/mbarna/Projects/mldevautomation/pipelines/training.py", line 49, in execute
    request = self.trainer.handle(request=request)
  File "/home/mbarna/Projects/mldevautomation/trainers/vissl_runner.py", line 42, in handle
    launch_distributed(
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 164, in launch_distributed
    raise e
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 150, in launch_distributed
    _distributed_worker(
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
    run_engine(
  File "/home/mbarna/Projects/vissl/vissl/engines/engine_registry.py", line 86, in run_engine
    engine.run_engine(
  File "/home/mbarna/Projects/vissl/vissl/engines/train.py", line 39, in run_engine
    train_main(
  File "/home/mbarna/Projects/vissl/vissl/engines/train.py", line 130, in train_main
    trainer.train()
  File "/home/mbarna/Projects/vissl/vissl/trainer/trainer_main.py", line 211, in train
    raise e
  File "/home/mbarna/Projects/vissl/vissl/trainer/trainer_main.py", line 193, in train
    task = train_step_fn(task)
  File "/home/mbarna/Projects/vissl/vissl/trainer/train_steps/standard_train_step.py", line 143, in standard_train_step
    model_output = task.model(sample["input"])
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 97, in __call__
    return self.forward(*args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 111, in forward
    out = self.classy_model(*args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 180, in forward
    return self.single_input_forward(batch, self._output_feature_names, self.heads)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 138, in single_input_forward
    return self.heads_forward(feats, heads)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 159, in heads_forward
    output = head(output)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/Projects/vissl/vissl/models/heads/mlp.py", line 111, in forward
    out = self.clf(batch)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/apex/amp/wrap.py", line 28, in wrapper
    return orig_fn(*new_args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 384x4)
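The shapes in the error are consistent with a batch of 2 being pushed through the heads one after another: the first head produces a (2, 3) tensor, which the second head's Linear(384, 4) then cannot consume. A minimal standalone reproduction (assuming a 384-dim trunk output, matching the 384 in the error):

import torch
import torch.nn as nn

# Three heads as configured: 384-dim trunk features in; 3, 4, and 3 classes out.
heads = [nn.Linear(384, 3), nn.Linear(384, 4), nn.Linear(384, 3)]

feats = torch.randn(2, 384)  # batch of 2 feature vectors from the trunk

# Applying the heads in series reproduces the failure exactly:
out = heads[0](feats)  # shape (2, 3)
out = heads[1](out)    # RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 384x4)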
After digging through the VISSL code, specifically the heads_forward() function in vissl/models/base_ssl_model.py, it looks like this bit of code is applying the heads in series rather than in parallel:
# Example case: Head consisting of several layers
elif (len(heads) > 1) and (len(feats) == 1):
    output = feats[0]
    for head in heads:
        output = head(output)
    # our model is multiple output.
    return [output]
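What I expected instead is for each head to consume the same trunk features independently, along these lines (my own sketch of the intended behavior, not actual VISSL code):

# Hypothetical parallel version (my sketch): every head sees the same
# trunk features and contributes its own output.
outputs = []
for head in heads:
    outputs.append(head(feats[0]))
return outputs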
I have checked that the head layers themselves are set up in parallel in the model.
It looks like I want the function to use the first if clause instead, but that clause seems to be meant for extracting outputs from multiple layers in the trunk rather than feeding the final layer's features to multiple heads. I assume I must be missing something in the configuration setup.
Incidentally, I did not see any guidance on how to structure the disk_filelist labels for multiple heads, so I assumed a 2D array where the columns represent the different heads:
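For example, something like this (the layout and filename here are my assumption, not anything I found documented):

import numpy as np

# One row per image, one column per head:
# column 0 -> head 1 (3 classes), column 1 -> head 2 (4 classes), column 2 -> head 3 (3 classes)
labels = np.array([
    [0, 2, 1],
    [1, 3, 0],
    [2, 0, 2],
])
np.save("train_labels.npy", labels)  # hypothetical filename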
I was also working on implementing my own loss function for this.
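Roughly along these lines (just a sketch of my intent, assuming the model returns one logits tensor per head and the labels are the 2D array described above):

import torch.nn.functional as F

def multi_head_loss(outputs, targets):
    # outputs: list of per-head logits, e.g. shapes (N, 3), (N, 4), (N, 3)
    # targets: (N, num_heads) integer label matrix, one column per head
    return sum(
        F.cross_entropy(logits, targets[:, i])
        for i, logits in enumerate(outputs)
    )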
Here is the full config. Thanks for your help!