Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0
347 stars 44 forks source link

Error in tutorials/reconstruction.ipynb #288

Open robmarkcole opened 3 months ago

robmarkcole commented 3 months ago

I updated the notebook to use `CHECKPOINT_PATH = "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt"`.

The cell that creates the data module then fails with the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 12
      1 # For model training, we stack chips from one sensor into batches of size 128.
      2 # This reduces the num_workers we need to load the batches and speeds up the
      3 # training process. Here, although the batch size is 1, the data module reads
      4 # batch of size 128.
      5 dm = ClayDataModule(
      6     data_dir=DATA_DIR,
      7     metadata_path=METADATA_PATH,
   (...)
     10     num_workers=1,
     11 )
---> 12 dm.setup(stage="fit")

File ~/model/src/datamodule.py:166, in ClayDataModule.setup(self, stage)
    163 print(f"Total number of chips: {len(chips_path)}")
    165 if stage == "fit":
--> 166     trn_paths, val_paths = train_test_split(
    167         chips_path,
    168         test_size=(1 - self.split_ratio),
    169         stratify=chips_platform,
    170         shuffle=True,
    171     )
    173     self.trn_ds = EODataset(
    174         chips_path=trn_paths,
    175         size=self.size,
    176         platforms=self.platforms,
    177         metadata=self.metadata,
    178     )
    179     self.trn_sampler = ClaySampler(
    180         dataset=self.trn_ds,
    181         platforms=self.platforms,
    182         batch_size=self.batch_size,
    183     )

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:213, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):
--> 213         return func(*args, **kwargs)
    214 except InvalidParameterError as e:
    215     # When the function is just a wrapper around an estimator, we allow
    216     # the function to delegate validation to the estimator, but we replace
    217     # the name of the estimator by the name of the function in the error
    218     # message to avoid confusion.
    219     msg = re.sub(
    220         r"parameter of \w+ must be",
    221         f"parameter of {func.__qualname__} must be",
    222         str(e),
    223     )

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/model_selection/_split.py:2660, in train_test_split(test_size, train_size, random_state, shuffle, stratify, *arrays)
   2657 arrays = indexable(*arrays)
   2659 n_samples = _num_samples(arrays[0])
-> 2660 n_train, n_test = _validate_shuffle_split(
   2661     n_samples, test_size, train_size, default_test_size=0.25
   2662 )
   2664 if shuffle is False:
   2665     if stratify is not None:

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/model_selection/_split.py:2308, in _validate_shuffle_split(n_samples, test_size, train_size, default_test_size)
   2305 n_train, n_test = int(n_train), int(n_test)
   2307 if n_train == 0:
-> 2308     raise ValueError(
   2309         "With n_samples={}, test_size={} and train_size={}, the "
   2310         "resulting train set will be empty. Adjust any of the "
   2311         "aforementioned parameters.".format(n_samples, test_size, train_size)
   2312     )
   2314 return n_train, n_test

ValueError: With n_samples=0, test_size=0.19999999999999996 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

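The error reports `n_samples=0`, which means no chip files were found under `DATA_DIR`, so there was nothing to split into training and validation sets. I assume there is an additional step to download the required data into `DATA_DIR` that is not listed in the tutorial.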