I updated the notebook to use CHECKPOINT_PATH = "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt".
When I run the cell that creates the datamodule, I get:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 12
1 # For model training, we stack chips from one sensor into batches of size 128.
2 # This reduces the num_workers we need to load the batches and speeds up the
3 # training process. Here, although the batch size is 1, the data module reads
4 # batch of size 128.
5 dm = ClayDataModule(
6 data_dir=DATA_DIR,
7 metadata_path=METADATA_PATH,
(...)
10 num_workers=1,
11 )
---> 12 dm.setup(stage="fit")
File ~/model/src/datamodule.py:166, in ClayDataModule.setup(self, stage)
163 print(f"Total number of chips: {len(chips_path)}")
165 if stage == "fit":
--> 166 trn_paths, val_paths = train_test_split(
167 chips_path,
168 test_size=(1 - self.split_ratio),
169 stratify=chips_platform,
170 shuffle=True,
171 )
173 self.trn_ds = EODataset(
174 chips_path=trn_paths,
175 size=self.size,
176 platforms=self.platforms,
177 metadata=self.metadata,
178 )
179 self.trn_sampler = ClaySampler(
180 dataset=self.trn_ds,
181 platforms=self.platforms,
182 batch_size=self.batch_size,
183 )
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:213, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
207 try:
208 with config_context(
209 skip_parameter_validation=(
210 prefer_skip_nested_validation or global_skip_validation
211 )
212 ):
--> 213 return func(*args, **kwargs)
214 except InvalidParameterError as e:
215 # When the function is just a wrapper around an estimator, we allow
216 # the function to delegate validation to the estimator, but we replace
217 # the name of the estimator by the name of the function in the error
218 # message to avoid confusion.
219 msg = re.sub(
220 r"parameter of \w+ must be",
221 f"parameter of {func.__qualname__} must be",
222 str(e),
223 )
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/model_selection/_split.py:2660, in train_test_split(test_size, train_size, random_state, shuffle, stratify, *arrays)
2657 arrays = indexable(*arrays)
2659 n_samples = _num_samples(arrays[0])
-> 2660 n_train, n_test = _validate_shuffle_split(
2661 n_samples, test_size, train_size, default_test_size=0.25
2662 )
2664 if shuffle is False:
2665 if stratify is not None:
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.11/site-packages/sklearn/model_selection/_split.py:2308, in _validate_shuffle_split(n_samples, test_size, train_size, default_test_size)
2305 n_train, n_test = int(n_train), int(n_test)
2307 if n_train == 0:
-> 2308 raise ValueError(
2309 "With n_samples={}, test_size={} and train_size={}, the "
2310 "resulting train set will be empty. Adjust any of the "
2311 "aforementioned parameters.".format(n_samples, test_size, train_size)
2312 )
2314 return n_train, n_test
ValueError: With n_samples=0, test_size=0.19999999999999996 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.
The error reports n_samples=0, so it appears that no chips were found in DATA_DIR. I assume there is an additional step to download the required data to DATA_DIR that is not listed here.
I updated the notebook to use
CHECKPOINT_PATH = "https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt"
When I run the cell that creates the datamodule, I get the ValueError shown above.
I assume there is an additional step to download the required data to DATA_DIR that is not listed here.