KaiyangZhou / Dassl.pytorch

A PyTorch toolbox for domain generalization, domain adaptation and semi-supervised learning.
MIT License

A new implementation of MixStyle #23

Open KaiyangZhou opened 3 years ago

KaiyangZhou commented 3 years ago

We have improved the implementation of MixStyle to make it more flexible.

Recall that MixStyle has two versions: random mixing and cross-domain mixing. The former randomly shuffles the batch dimension, while the latter mixes the first half of a batch with the second half.
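
To make the difference concrete, here is a minimal sketch of how the mixing partners could be chosen in each version. It uses plain Python lists rather than tensors, and the function name mixing_partners, the rng argument, and the even-batch-size requirement are assumptions made for illustration, not the actual Dassl code.

```python
import random


def mixing_partners(batch_size, mix="random", rng=random):
    """Return perm, where sample i is mixed with sample perm[i].

    Illustrative helper only; the name and signature are hypothetical.
    """
    indices = list(range(batch_size))
    if mix == "random":
        # Random mixing: shuffle the whole batch dimension.
        rng.shuffle(indices)
        return indices
    if mix == "crossdomain":
        # Cross-domain mixing: this sketch assumes an even batch whose
        # two halves come from different domains.
        if batch_size % 2 != 0:
            raise ValueError("this sketch assumes an even batch size")
        half = batch_size // 2
        first, second = indices[:half], indices[half:]
        rng.shuffle(first)
        rng.shuffle(second)
        # First-half samples get second-half partners, and vice versa.
        return second + first
    raise ValueError(f"unknown mix mode: {mix!r}")
```

With mix="crossdomain", every index in the first half is paired with an index from the second half, so styles are always exchanged across the two halves, whereas mix="random" may pair samples from the same half.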

After merging MixStyle2 into MixStyle, the two versions are now controlled by a new variable, self.mix, which takes either 'random' or 'crossdomain', corresponding to the two versions respectively. This variable can be set during initialization, e.g., self.mixstyle = MixStyle(mix='random'). It can also be changed on the fly. For instance, to apply random mixing at the current step, simply call model.apply(random_mixstyle), or model.apply(crossdomain_mixstyle) if you prefer cross-domain mixing.
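
The on-the-fly switch relies on the apply pattern: a function is called on every submodule and only touches the MixStyle layers. The sketch below mimics this with plain-Python stand-ins so it runs without PyTorch. The MixStyle and Model classes here are simplified placeholders (assumptions), and the function bodies only show one way such switches could be written, not necessarily how Dassl does it.

```python
class MixStyle:
    """Stand-in for the real layer; only the mode attribute is modeled."""

    def __init__(self, mix="random"):
        self.mix = mix


class Model:
    """Stand-in for nn.Module: apply(fn) calls fn on each child, then on self."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def apply(self, fn):
        for layer in self.layers:
            fn(layer)
        fn(self)
        return self


def random_mixstyle(m):
    # Only MixStyle layers are affected; everything else is skipped.
    if isinstance(m, MixStyle):
        m.mix = "random"


def crossdomain_mixstyle(m):
    if isinstance(m, MixStyle):
        m.mix = "crossdomain"


# A toy model with two MixStyle layers and one unrelated layer.
model = Model(MixStyle(mix="random"), object(), MixStyle(mix="random"))
model.apply(crossdomain_mixstyle)  # every MixStyle layer now uses 'crossdomain'
```

Calling model.apply(random_mixstyle) afterwards switches all MixStyle layers back to random mixing.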

We have also added new context managers to manage MixStyle in the forward pass. If your model has MixStyle layers that were initially activated and you would like to deactivate them temporarily, you can do

# print(MixStyle._activated): True
with run_without_mixstyle(model):
    # print(MixStyle._activated): False
    output = model(input)
# print(MixStyle._activated): True

Conversely, if the MixStyle layers were initially deactivated and you want to use them temporarily, you can do

# print(MixStyle._activated): False
with run_with_mixstyle(model):
    # print(MixStyle._activated): True
    output = model(input)
# print(MixStyle._activated): False
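
For readers curious how such context managers can be put together, here is a rough sketch built on contextlib. The MixStyle and Model classes are plain-Python stand-ins (assumptions) so that the example runs without PyTorch, and the bodies below are one plausible way to get the behaviour shown above, not a copy of the Dassl source.

```python
from contextlib import contextmanager


class MixStyle:
    """Stand-in layer; only the activation flag is modeled."""

    def __init__(self, activated=True):
        self._activated = activated


class Model:
    """Stand-in for nn.Module: apply(fn) calls fn on each child, then on self."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def apply(self, fn):
        for layer in self.layers:
            fn(layer)
        fn(self)
        return self


def activate_mixstyle(m):
    if isinstance(m, MixStyle):
        m._activated = True


def deactivate_mixstyle(m):
    if isinstance(m, MixStyle):
        m._activated = False


@contextmanager
def run_without_mixstyle(model):
    # Assumes MixStyle was activated before entering the block.
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        # Runs even if the forward pass raises, so the flag is always reset.
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model):
    # Assumes MixStyle was deactivated before entering the block.
    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)
```

Note that in this sketch each manager resets the flag to the state it assumes on entry rather than recording the actual previous state, which is why the two usage examples each start from a known initial state.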

You can also change self.mix while using run_with_mixstyle, e.g.

# print(MixStyle._activated): False
# print(MixStyle.mix): random
with run_with_mixstyle(model, mix='crossdomain'):
    # print(MixStyle._activated): True
    # print(MixStyle.mix): crossdomain
    output = model(input)
# print(MixStyle._activated): False
# print(MixStyle.mix): crossdomain

Note, however, that a change to self.mix made through run_with_mixstyle persists after the with block exits. To set it back, call model.apply(random_mixstyle) or model.apply(crossdomain_mixstyle) manually.
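
If you would rather have the mixing mode revert automatically, one option is a small wrapper that records each layer's state on entry and restores it on exit. The helper below, run_with_mixstyle_restoring, is hypothetical and not part of Dassl. It is shown with plain-Python stand-ins for MixStyle and the model (assumptions) so that it runs without PyTorch; with a real model, the same idea should carry over as long as the layers expose mix and _activated.

```python
from contextlib import contextmanager


class MixStyle:
    """Stand-in layer; models only the mode and the activation flag."""

    def __init__(self, mix="random", activated=False):
        self.mix = mix
        self._activated = activated


class Model:
    """Stand-in for nn.Module: apply(fn) calls fn on each child, then on self."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def apply(self, fn):
        for layer in self.layers:
            fn(layer)
        fn(self)
        return self


@contextmanager
def run_with_mixstyle_restoring(model, mix):
    """Hypothetical helper: activate MixStyle with a temporary mixing mode."""
    saved = []

    def enter(m):
        if isinstance(m, MixStyle):
            # Remember the previous state so it can be restored later.
            saved.append((m, m.mix, m._activated))
            m.mix = mix
            m._activated = True

    model.apply(enter)
    try:
        yield
    finally:
        # Restore every layer, even if the forward pass raised.
        for layer, old_mix, old_activated in saved:
            layer.mix = old_mix
            layer._activated = old_activated


layer = MixStyle(mix="random", activated=False)
model = Model(layer)
with run_with_mixstyle_restoring(model, mix="crossdomain"):
    inside = (layer.mix, layer._activated)
after = (layer.mix, layer._activated)
```

Here the layer uses 'crossdomain' only inside the with block and goes back to its original mode and activation state afterwards.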