mzweilin opened 9 months ago
Have you looked at torch's collate functionality? It can walk Python data structures and apply a function to their elements. It might be helpful to reuse that here.
torch.utils.data.default_convert() is closer to our convert(), but it is implemented as an internal if-else chain, so it is not designed to be extended. We are not collating data here, so I don't want to abuse default_collate() or collate().
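To make the comparison concrete, here is a small sketch (assuming PyTorch 1.11 or later, where default_convert is exported from torch.utils.data; the batch contents are made-up illustrative data). default_convert recursively turns NumPy arrays into tensors inside nested dicts and lists, but only in that direction, and its signature takes just the data, with no parameter for registering handlers for extra types.

```python
import numpy as np
import torch
from torch.utils.data import default_convert

# A nested structure holding NumPy arrays (illustrative data only).
batch = {
    "image": np.zeros((2, 3), dtype=np.float32),
    "meta": [np.array([1, 2])],
}

# default_convert recurses through the dict and the list and turns each ndarray into a tensor.
converted = default_convert(batch)
print(type(converted["image"]))  # <class 'torch.Tensor'>
print(type(converted["meta"][0]))  # <class 'torch.Tensor'>

# It only goes one way: a tensor is returned unchanged, never turned back into an ndarray.
t = torch.ones(2)
print(default_convert(t) is t)  # True
```

The missing reverse direction (tensor to ndarray) is one reason a separate convert() is needed for the ARMORY use case.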
What does this PR do?
This PR adds a recursive converter, mart.transforms.tensor_array.convert, that converts between NumPy arrays and PyTorch tensors nested inside complex data structures. This is useful when running MART attacks in ARMORY, because ARMORY passes data around as structures of NumPy arrays.
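For readers unfamiliar with the idea, here is a minimal sketch of a recursive NumPy/PyTorch converter. It is not MART's implementation of mart.transforms.tensor_array.convert: the names walk, array_to_tensor, and tensor_to_array are hypothetical, and it assumes only dict, list, and tuple containers need to be traversed. It uses functools.singledispatch so that further container types can be registered later instead of growing an internal if-else chain.

```python
from functools import singledispatch
from typing import Any, Callable

import numpy as np
import torch

# Hypothetical helper names for illustration; not MART's API.


@singledispatch
def walk(obj: Any, leaf_fn: Callable[[Any], Any]) -> Any:
    """Fallback for leaves (arrays, tensors, strings, numbers): apply leaf_fn."""
    return leaf_fn(obj)


# Container handlers rebuild the structure and recurse into every element.
# More container types can be added with walk.register without editing this code.
@walk.register(dict)
def _walk_dict(obj: dict, leaf_fn):
    return {key: walk(value, leaf_fn) for key, value in obj.items()}


@walk.register(list)
def _walk_list(obj: list, leaf_fn):
    return [walk(item, leaf_fn) for item in obj]


@walk.register(tuple)
def _walk_tuple(obj: tuple, leaf_fn):
    # Namedtuples also dispatch here and come back as plain tuples.
    return tuple(walk(item, leaf_fn) for item in obj)


def array_to_tensor(x: Any) -> Any:
    # torch.from_numpy shares memory with the ndarray instead of copying it.
    return torch.from_numpy(x) if isinstance(x, np.ndarray) else x


def tensor_to_array(x: Any) -> Any:
    # Detach from the autograd graph and move to host memory before calling .numpy().
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x


sample = {"x": np.ones((2, 2), dtype=np.float32), "boxes": [(np.zeros(4), "label")]}
as_tensors = walk(sample, array_to_tensor)
round_trip = walk(as_tensors, tensor_to_array)
print(type(as_tensors["x"]).__name__)  # Tensor
print(type(round_trip["boxes"][0][0]).__name__)  # ndarray
```

Because torch.from_numpy shares memory, in-place edits to the resulting tensor are visible in the original array; torch.tensor(x) would make a copy instead.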
Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
pytest
CUDA_VISIBLE_DEVICES=0 python -m mart experiment=CIFAR10_CNN_Adv trainer=gpu trainer.precision=16
reports 70% (21 sec/epoch).
CUDA_VISIBLE_DEVICES=0,1 python -m mart experiment=CIFAR10_CNN_Adv trainer=ddp trainer.precision=16 trainer.devices=2 model.optimizer.lr=0.2 trainer.max_steps=2925 datamodule.ims_per_batch=256 datamodule.world_size=2
reports 70% (14 sec/epoch).
Before submitting
Ran the pre-commit run -a command without errors.
Did you have fun?
Make sure you had fun coding 🙃