Closed crazyn2 closed 8 months ago
I've fixed the bug in this piece of code by resetting self.gen before creating each dataloader:
class CIFAR10Dm(pl.LightningDataModule):
    """CIFAR-10 data module for one-class (normal vs. outlier) experiments.

    One CIFAR-10 class is treated as "normal" (target 0); the nine remaining
    classes are outliers (target 1). The training set contains every training
    image of the normal class, optionally contaminated with randomly chosen
    outlier images. The validation and test loaders both serve the full
    CIFAR-10 test split.

    Args:
        batch_size: Batch size used by every dataloader.
        normal_class: CIFAR-10 class index (0-9) that is labelled 0.
        seed: Seed passed to ``pl.seed_everything`` and used to seed the
            ``torch.Generator`` handed to every dataloader.
        radio: Contamination ratio in ``[0, 1)``: the fraction of the final
            training subset that is drawn from the outlier classes.
        num_workers: Number of dataloader worker processes.
        root: Directory that holds the CIFAR-10 files. The dataset is not
            downloaded here, so it must already be present.
        dataset_name: Stored on the instance; not otherwise used in this
            class.
        gcn: If true, apply l1 global contrast normalization followed by a
            per-class min-max rescaling; otherwise normalize each channel
            with mean 0.5 and std 0.5.

    Raises:
        ValueError: If ``radio`` is outside ``[0, 1)`` or ``normal_class``
            is not one of the ten CIFAR-10 class indices.
    """

    # Pre-computed (min, max) values per class, measured on the training
    # data after applying global contrast normalization.
    _GCN_MIN_MAX = (
        (-28.94083453598571, 13.802961825439636),
        (-6.681770233365245, 9.158067708230273),
        (-34.924463588638204, 14.419298165027628),
        (-10.599172931391799, 11.093187820377565),
        (-11.945022995801637, 10.628045447867583),
        (-9.691969487694928, 8.948326776180823),
        (-9.174940012342555, 13.847014686472365),
        (-6.876682005899029, 12.282371383343161),
        (-15.603507135507172, 15.2464923804279),
        (-6.132882973622672, 8.046098172351265),
    )

    def __init__(
        self,
        batch_size,
        normal_class,
        seed,
        radio=0.0,
        num_workers=3,
        root="./data/",
        dataset_name="cifar10",
        gcn=True,
    ):
        """Store the configuration, seed all RNGs and build the datasets.

        See the class docstring for the meaning of each argument.
        """
        # radio / (1 - radio) below is undefined for radio == 1 and
        # meaningless outside [0, 1); fail early with a clear message.
        if not 0.0 <= radio < 1.0:
            raise ValueError(
                f"radio must satisfy 0 <= radio < 1, got {radio!r}")
        super().__init__()
        self.save_hyperparameters()
        pl.seed_everything(seed, workers=True)

        self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.root = root
        # Contamination ratio of the training data.
        self.radio = radio
        self.normal_class = normal_class
        self.num_workers = num_workers
        # Only one class per training set is considered normal.
        self.normal_classes = (normal_class,)
        self.outlier_classes = list(range(0, 10))
        self.outlier_classes.remove(normal_class)
        self.seed = seed
        self.gcn = gcn
        self.gen = torch.Generator()
        self.gen.manual_seed(self.seed)

        self._build_datasets()

    def _build_transform(self):
        """Return the image transform selected by ``self.gcn``."""
        steps = [transforms.ToTensor()]
        if self.gcn:
            low, high = self._GCN_MIN_MAX[self.normal_class]
            steps += [
                transforms.Lambda(
                    lambda x: global_contrast_normalization(x, scale='l1')),
                # (x - low) / (high - low), using the normal class's range.
                transforms.Normalize([low] * 3, [high - low] * 3),
            ]
        else:
            steps.append(transforms.Normalize([0.5] * 3, [0.5] * 3))
        return transforms.Compose(steps)

    def _build_datasets(self):
        """Create ``self.train_cifar10`` and ``self.test_cifar10``."""
        transform = self._build_transform()
        # Binary target: 0 for the normal class, 1 for any outlier class.
        target_transform = transforms.Lambda(
            lambda x: int(x in self.outlier_classes))

        full_train = CIFAR10(
            root=self.root,
            train=True,
            transform=transform,
            target_transform=target_transform,
        )

        # Split the training indices into normal and outlier ("dirty") ones.
        train_indices = []
        dirty_indices = []
        for idx, target in enumerate(full_train.targets):
            if target in self.normal_classes:
                train_indices.append(idx)
            else:
                dirty_indices.append(idx)

        # Add k outliers so that k / (n_normal + k) == radio. The train
        # dataloader shuffles, so normal and outlier samples get mixed.
        n_dirty = int(len(train_indices) * self.radio / (1 - self.radio))
        train_indices += sample(dirty_indices, n_dirty)
        self.train_cifar10 = Subset(full_train, train_indices)

        self.test_cifar10 = CIFAR10(
            root=self.root,
            train=False,
            transform=transform,
            target_transform=target_transform,
        )

    def _make_loader(self, dataset, shuffle=False):
        """Build a ``DataLoader`` for ``dataset`` with a freshly seeded RNG.

        The generator is re-created and re-seeded on every call so that each
        dataloader starts from the same RNG state, independent of how many
        loaders were created before it.
        """
        self.gen = torch.Generator()
        self.gen.manual_seed(self.seed)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=seed_worker,
            generator=self.gen,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            shuffle=shuffle,
            # NOTE(review): drop_last also discards the final partial batch
            # of the evaluation loaders; kept to preserve existing results.
            drop_last=True,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the (contaminated) training set."""
        return self._make_loader(self.train_cifar10, shuffle=True)

    def test_dataloader(self):
        """Return the unshuffled loader over the CIFAR-10 test split."""
        return self._make_loader(self.test_cifar10)

    def load_dataloader(self, dataset):
        """Return an unshuffled loader over an arbitrary ``dataset``."""
        return self._make_loader(dataset)

    def val_dataloader(self):
        """Return the unshuffled loader over the CIFAR-10 test split."""
        return self._make_loader(self.test_cifar10)
Bug description
The same logical piece of code produced different results, e.g. for svdd_roc_auc_sk.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA: - GPU: - NVIDIA GeForce RTX 3090 - NVIDIA GeForce RTX 3080 Ti - NVIDIA GeForce RTX 3080 Ti - available: True - version: 12.1 * Lightning: - kmeans-pytorch: 0.3 - lightning: 2.1.2 - lightning-cloud: 0.5.39 - lightning-utilities: 0.9.0 - pytorch-lightning: 2.0.9.post0 - torch: 2.1.1 - torch-tb-profiler: 0.4.3 - torchaudio: 2.1.0 - torchdata: 0.7.1 - torchmetrics: 1.2.0 - torchsummary: 1.5.1 - torchtext: 0.16.1 - torchvision: 0.16.1 * Packages: - absl-py: 2.0.0 - addict: 2.4.0 - aioboto3: 11.3.0 - aiobotocore: 2.6.0 - aiofiles: 23.2.1 - aiohttp: 3.8.5 - aiohttp-cors: 0.7.0 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - alembic: 1.12.0 - altair: 5.1.2 - annotated-types: 0.5.0 - anyio: 3.7.1 - arrow: 1.3.0 - astor: 0.8.1 - async-timeout: 4.0.3 - attrs: 23.1.0 - backoff: 2.2.1 - beautifulsoup4: 4.12.2 - black: 23.10.1 - blessed: 1.20.0 - blinker: 1.6.3 - boto3: 1.28.17 - botocore: 1.31.17 - bottle: 0.12.25 - brotli: 1.0.9 - cachetools: 5.3.1 - certifi: 2023.11.17 - cffi: 1.16.0 - charset-normalizer: 2.0.4 - click: 8.1.7 - cloudpickle: 2.2.1 - cmaes: 0.10.0 - colorama: 0.4.6 - colorful: 0.5.5 - colorlog: 6.7.0 - contextlib2: 21.6.0 - contourpy: 1.1.1 - croniter: 1.4.1 - cryptography: 41.0.7 - cycler: 0.12.0 - databricks-cli: 0.18.0 - dateutils: 0.6.12 - deepdiff: 6.6.0 - deeplake: 3.8.0 - dill: 0.3.7 - distlib: 0.3.8 - docker: 6.1.3 - einops: 0.7.0 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - fastapi: 0.103.2 - ffmpy: 0.3.1 - filelock: 3.11.0 - flake8: 6.0.0 - flask: 2.3.3 - fonttools: 4.43.0 - frozenlist: 1.4.0 - fsspec: 2023.10.0 - gitdb: 4.0.10 - gitpython: 3.1.37 - gmpy2: 2.1.2 - google-api-core: 2.15.0 - google-auth: 2.23.2 - google-auth-oauthlib: 1.0.0 - googleapis-common-protos: 1.62.0 - gpustat: 1.1.1 - gradio-client: 0.7.0 - greenlet: 3.0.0 - grpcio: 1.59.0 - gunicorn: 21.2.0 - gym: 0.26.2 - gym-notices: 0.0.8 - h11: 0.14.0 - httpcore: 1.0.2 - httpx: 0.25.2 - hub: 3.0.1 - huggingface-hub: 0.19.4 - humbug: 0.3.2 - idna: 3.4 - imageio: 2.31.6 - importlib-metadata: 
6.8.0 - importlib-resources: 6.1.1 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.2 - json-tricks: 3.17.3 - jsonschema: 4.19.1 - jsonschema-specifications: 2023.7.1 - kiwisolver: 1.4.5 - kmeans-pytorch: 0.3 - kornia: 0.7.0 - lazy-loader: 0.3 - libdeeplake: 0.0.83 - lightning: 2.1.2 - lightning-cloud: 0.5.39 - lightning-utilities: 0.9.0 - mako: 1.2.4 - markdown: 3.4.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - matplotlib: 3.8.0 - mccabe: 0.7.0 - mdurl: 0.1.2 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - mlflow: 2.7.1 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.4 - multiprocess: 0.70.15 - mypy-extensions: 1.0.0 - nest-asyncio: 1.5.8 - networkx: 3.1 - nni: 3.0 - numcodecs: 0.12.0 - numpy: 1.26.2 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-ml-py: 12.535.133 - nvidia-nccl-cu12: 2.18.1 - nvidia-nvjitlink-cu12: 12.3.101 - nvidia-nvtx-cu12: 12.1.105 - oauthlib: 3.2.2 - opencensus: 0.11.3 - opencensus-context: 0.1.3 - opencv-python: 4.8.1.78 - openpyxl: 3.1.2 - optuna: 3.5.0 - optuna-dashboard: 0.12.0 - ordered-set: 4.1.0 - orjson: 3.9.10 - packaging: 23.2 - pandas: 2.1.1 - pathos: 0.3.1 - pathspec: 0.11.2 - pep8: 1.7.1 - pillow: 10.0.1 - pip: 23.3.1 - platformdirs: 3.11.0 - pox: 0.3.3 - ppft: 1.7.6.7 - prettytable: 3.9.0 - prometheus-client: 0.19.0 - protobuf: 4.24.4 - psutil: 5.9.5 - py-spy: 0.3.14 - pyarrow: 13.0.0 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 1.9.1 - pydantic-core: 2.14.5 - pydub: 0.25.1 - pyflakes: 3.0.1 - pygame: 2.5.2 - pygments: 2.16.1 - pyjwt: 2.8.0 - pymysql: 1.1.0 - pyopenssl: 23.2.0 - pyparsing: 3.1.1 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - 
python-multipart: 0.0.6 - pythonwebhdfs: 0.2.3 - pytorch-lightning: 2.0.9.post0 - pytz: 2023.3.post1 - pyyaml: 6.0.1 - querystring-parser: 1.2.4 - ray: 2.8.1 - readchar: 4.0.5 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - responses: 0.24.1 - rich: 13.6.0 - rpds-py: 0.10.4 - rsa: 4.9 - s3transfer: 0.6.2 - safetensors: 0.4.0 - schema: 0.7.5 - scikit-image: 0.22.0 - scikit-learn: 1.3.1 - scipy: 1.11.3 - semantic-version: 2.10.0 - setuptools: 68.0.0 - shellingham: 1.5.4 - simplejson: 3.19.2 - six: 1.16.0 - smart-open: 6.4.0 - smmap: 5.0.1 - sniffio: 1.3.0 - soupsieve: 2.5 - sqlalchemy: 2.0.21 - sqlparse: 0.4.4 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - tabulate: 0.9.0 - tensorboard: 2.14.1 - tensorboard-data-server: 0.7.1 - tensorboardx: 2.6.2.2 - threadpoolctl: 3.2.0 - tifffile: 2023.9.26 - tokenizers: 0.15.0 - tomli: 2.0.1 - tomlkit: 0.12.0 - toolz: 0.12.0 - torch: 2.1.1 - torch-tb-profiler: 0.4.3 - torchaudio: 2.1.0 - torchdata: 0.7.1 - torchmetrics: 1.2.0 - torchsummary: 1.5.1 - torchtext: 0.16.1 - torchvision: 0.16.1 - tqdm: 4.66.1 - traitlets: 5.11.2 - transformers: 4.35.2 - triton: 2.1.0 - typeguard: 4.1.2 - typer: 0.9.0 - types-python-dateutil: 2.8.19.14 - typing-extensions: 4.7.1 - tzdata: 2023.3 - urllib3: 1.26.18 - uvicorn: 0.23.2 - virtualenv: 20.21.0 - wcwidth: 0.2.8 - websocket-client: 1.6.3 - websockets: 11.0.3 - werkzeug: 3.0.0 - wheel: 0.41.2 - wrapt: 1.15.0 - yapf: 0.40.2 - yarl: 1.9.2 - zipp: 3.17.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.11.5 - release: 6.2.0-39-generic - version: #40~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov 16 10:53:04 UTC 2More info
No response