A PyTorch implementation of Adversarial Discriminative Domain Adaptation (ADDA).
I have only tested it on MNIST → USPS. To reproduce that experiment, run:

```shell
python3 main.py
```
The experiment uses three simple networks:
LeNet encoder
```
LeNetEncoder (
  (encoder): Sequential (
    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (2): ReLU ()
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): Dropout2d (p=0.5)
    (5): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (6): ReLU ()
  )
  (fc1): Linear (800 -> 500)
)
```
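The `800` inputs of `fc1` follow directly from the layer sizes. Here is a small sketch of that arithmetic. It assumes single-channel 28×28 inputs, which is the MNIST size. Each unpadded 5×5 convolution trims 4 pixels from the side length, and each 2×2 max-pool halves it. ReLU and Dropout2d leave the shape unchanged. The helper names `conv_out` and `lenet_flat_features` are illustrative and not taken from this repository.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length after a Conv2d or MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flat_features(input_size: int = 28) -> int:
    """Number of flattened features reaching fc1 for a square 1-channel input."""
    size = conv_out(input_size, kernel=5)        # Conv2d(1, 20, 5):  28 -> 24
    size = conv_out(size, kernel=2, stride=2)    # MaxPool2d(2):      24 -> 12
    size = conv_out(size, kernel=5)              # Conv2d(20, 50, 5): 12 -> 8
    size = conv_out(size, kernel=2, stride=2)    # MaxPool2d(2):       8 -> 4
    channels = 50                                # out_channels of the second conv
    return channels * size * size                # 50 * 4 * 4


print(lenet_flat_features(28))  # 800
print(lenet_flat_features(16))  # 50
```

If you fed native 16×16 USPS digits through the same stack, only 50 features would remain. USPS images therefore have to be brought to 28×28 before they reach this encoder.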
LeNet classifier
```
LeNetClassifier (
  (fc2): Linear (500 -> 10)
)
```
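The printouts above come from an older PyTorch release (note the `Linear (800 -> 500)` style). Below is a hedged sketch of the encoder and classifier for a current release. It mirrors only the layers the printouts list. Any functional operations inside the original `forward` methods do not appear in a module printout, so they are not reproduced here.

The sketch also shows how the rows of the results table are formed: the classifier is shared, and only the encoder is swapped. Starting the target encoder as a copy of the source encoder follows the ADDA paper. The variable names are hypothetical.

```python
import copy

import torch
from torch import nn


class LeNetEncoder(nn.Module):
    """1x28x28 digit -> 500-d feature, mirroring the printed LeNetEncoder layers."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(50 * 4 * 4, 500)  # 800 flattened conv features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat_maps = self.encoder(x)                      # (N, 50, 4, 4)
        return self.fc1(feat_maps.flatten(start_dim=1))  # (N, 500)


class LeNetClassifier(nn.Module):
    """500-d feature -> 10 digit logits, mirroring the printed LeNetClassifier."""

    def __init__(self) -> None:
        super().__init__()
        self.fc2 = nn.Linear(500, 10)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        return self.fc2(feat)


# Hypothetical usage: the classifier is shared; only the encoder is swapped.
src_encoder = LeNetEncoder()
classifier = LeNetClassifier()
tgt_encoder = copy.deepcopy(src_encoder)  # ADDA starts the target encoder from the source weights

src_encoder.eval()  # eval() disables Dropout2d
tgt_encoder.eval()
images = torch.randn(4, 1, 28, 28)  # random stand-in for a batch of 28x28 digits
with torch.no_grad():
    src_logits = classifier(src_encoder(images))
    tgt_logits = classifier(tgt_encoder(images))
print(src_logits.shape)  # torch.Size([4, 10])
```

Before adaptation, both encoders produce identical logits. Only the adversarial stage changes the target encoder's weights.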
Discriminator
```
Discriminator (
  (layer): Sequential (
    (0): Linear (500 -> 500)
    (1): ReLU ()
    (2): Linear (500 -> 500)
    (3): ReLU ()
    (4): Linear (500 -> 2)
    (5): LogSoftmax ()
  )
)
```
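The discriminator takes a 500-d encoder feature and ends in `LogSoftmax` over two classes (source vs. target). Its natural loss is therefore a negative log-likelihood.

Below is a minimal pure-Python sketch of ADDA's two adversarial objectives. It assumes the GAN-style inverted-label formulation from the ADDA paper and an arbitrary label convention of 1 = source, 0 = target:

- The discriminator is trained to label source and target features correctly.
- The target encoder is then trained with the target label flipped to "source".

During this stage, only the target encoder and the discriminator are updated; the source encoder and classifier stay fixed. The helper names are illustrative.

```python
import math
from typing import List, Sequence

# Assumed label convention for this sketch only.
SOURCE, TARGET = 1, 0


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax, the role of the Discriminator's last layer."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def nll(log_probs: Sequence[float], label: int) -> float:
    """Per-sample negative log-likelihood of `label`, as nn.NLLLoss computes it."""
    return -log_probs[label]


def discriminator_loss(src_logits: Sequence[float], tgt_logits: Sequence[float]) -> float:
    """Discriminator step: classify source features as SOURCE, target features as TARGET."""
    return nll(log_softmax(src_logits), SOURCE) + nll(log_softmax(tgt_logits), TARGET)


def target_encoder_loss(tgt_logits: Sequence[float]) -> float:
    """Target-encoder step: inverted label, rewarding target features judged SOURCE."""
    return nll(log_softmax(tgt_logits), SOURCE)


# Pre-softmax scores from a discriminator that separates the domains easily.
src_out = [-2.0, 2.0]  # confidently "source"
tgt_out = [2.0, -2.0]  # confidently "target"
print(discriminator_loss(src_out, tgt_out))  # small: the discriminator is winning
print(target_encoder_loss(tgt_out))          # large: strong signal for the target encoder
```

When the discriminator wins easily, its own loss is small while the target encoder's loss is large. That large loss is the gradient signal that pulls target features toward the source feature distribution.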
| Model | MNIST (Source) | USPS (Target) |
|---|---|---|
| Source Encoder + Source Classifier | 99.140000% | 83.978495% |
| Target Encoder + Source Classifier | | 97.634409% |
Domain adaptation does work: on USPS, the adapted target encoder reaches 97.63% accuracy, compared with 83.98% for the source-only encoder.