KevinMusgrave / pytorch-adapt

Domain adaptation made easy. Fully featured, modular, and customizable.
https://kevinmusgrave.github.io/pytorch-adapt/
MIT License
359 stars, 15 forks

Use DANN with target labels #95

Closed (rtaiello closed this 1 year ago)

rtaiello commented 1 year ago

Hi @KevinMusgrave,

How can I use DANN with target labels? I tried the following:

from pytorch_adapt.hooks import DANNHook, CLossHook, FeaturesAndLogitsHook
G.count, C.count, D.count = 0, 0, 0
f_hook = FeaturesAndLogitsHook(domains=["src", "target"])
c_hook = CLossHook(f_hook=f_hook)
hook = DANNHook(opts, c_hook=c_hook)
model_counts = validate_hook(hook, list(data.keys()))
outputs, losses = hook({**models, **data})
print_info(model_counts, outputs, losses, G, C, D)

But I'm having this issue:

ValueError: in DANNHook: __call__
in ChainHook: __call__
in OptimizerHook: __call__
in ChainHook: __call__
in ChainHook: __call__
in CLossHook: __call__
too many values to unpack (expected 1)

Thanks in advance!

KevinMusgrave commented 1 year ago

I pushed a change to the dev branch:

pip install pytorch-adapt==0.0.82.dev0

Now your code should work, assuming that target_labels is available in your data dict.
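As a quick sanity check before calling the hook, something like the sketch below can confirm that the labels are present for every domain. The helper name `missing_label_keys` is hypothetical and not part of pytorch-adapt, and the `src_labels` key is assumed to follow the same naming pattern as `target_labels`.

```python
def missing_label_keys(data, domains=("src", "target")):
    """Return the '<domain>_labels' keys that are absent from the data dict.

    Hypothetical helper, not part of pytorch-adapt. It assumes label keys
    are named '<domain>_labels', as with 'target_labels'.
    """
    return [f"{d}_labels" for d in domains if f"{d}_labels" not in data]


# Placeholder batch: plain lists stand in for real tensors.
batch = {
    "src_imgs": [0.1, 0.2],
    "src_labels": [0, 1],
    "target_imgs": [0.3, 0.4],
    # "target_labels" is deliberately left out
}

missing = missing_label_keys(batch)
print(missing)
```

If the returned list is non-empty, the data dict would need those keys added before the classification loss can be computed on that domain.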

You can also pass in domains directly to CLossHook:

c_hook = CLossHook(domains=["src", "target"])

Your code broke because CLossHook was hard-coded to expect only "src" logits. So when FeaturesAndLogitsHook returned two logits tensors (one per domain), unpacking them into a single variable raised "too many values to unpack".
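The same failure mode can be reproduced in plain Python. This is only an illustration of the unpacking error, not the library's actual code: `fake_logits` is a hypothetical stand-in for a hook that returns one logits value per requested domain.

```python
def fake_logits(domains):
    # Hypothetical stand-in: returns one placeholder "logits" value per domain.
    return [f"{d}_logits" for d in domains]


# One domain: unpacking into a single name works.
[src_logits] = fake_logits(["src"])

# Two domains: the same single-name unpack raises ValueError.
try:
    [src_logits] = fake_logits(["src", "target"])
except ValueError as e:
    error_message = str(e)
    print(error_message)
```

The printed message begins with "too many values to unpack", matching the last line of the traceback above.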

Let me know if the latest version works and I'll push it to 0.0.82.

rtaiello commented 1 year ago

It worked, many thanks!