fidelity / stoke

A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.
https://fidelity.github.io/stoke/
Apache License 2.0

TypeError: intercept_args() got an unexpected keyword argument 'multiprocessing_context' #23

Closed rushi-the-neural-arch closed 2 years ago

rushi-the-neural-arch commented 2 years ago

I think the term multiprocessing_context isn't being used anywhere concretely but still appears in the DataLoader object, which causes the issue. This could be a simple bug as well, but I couldn't figure out the exact cause. The error logs are below:

File "/home/..../stoke/stoke.py", line 835, in DataLoader persistent_workers=persistent_workers, File "/..../stoke/data.py", line 127, in __init__ persistent_workers=persistent_workers, TypeError: intercept_args() got an unexpected keyword argument 'multiprocessing_context' ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 16888) of binary: /anaconda/envs/py37_default/bin/python

ncilfone commented 2 years ago

Can you list the following to help debug:

ncilfone commented 2 years ago

Right now the Stoke code just shims the basic PyTorch DataLoader to deal with underlying device placement.

If I had to guess, this error is most likely coming from the PyTorch end and not the Stoke end, probably a configuration issue in the underlying PyTorch setup... but if you provide more detail I can investigate further.
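
For context, the general shape of that kind of shim (illustrative only, this is not Stoke's actual implementation; the class name and device handling here are just for the example) is roughly:

import torch
from torch.utils.data import DataLoader


class DeviceDataLoader(DataLoader):
    # Illustrative shim: forward everything to the stock DataLoader and move
    # each yielded batch to a target device
    def __init__(self, *args, device=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = device if device is not None else torch.device("cpu")

    def __iter__(self):
        for batch in super().__iter__():
            # Assumes each batch is a tensor or a tuple/list of tensors
            if isinstance(batch, (tuple, list)):
                yield [t.to(self._device) for t in batch]
            else:
                yield batch.to(self._device)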

rushi-the-neural-arch commented 2 years ago

Yeah sure, here is the info

Environment:

rushi-the-neural-arch commented 2 years ago

Here is the sample code which I am following from the documentation:

    amp_config = AMPConfig(
            init_scale=2.**14
        )

    # Custom DDP configuration
    # Automatically swap out batch_norm layers with sync_batch_norm layers
    # Notice here we have to deal with the local rank parameter that DDP needs (from env or cmd line)
    ddp_config = DDPConfig(
        local_rank= int(os.getenv('LOCAL_RANK')),
        convert_to_sync_batch_norm=True
    )

    # Custom OSS configuration
    # activate broadcast_fp16 -- Compress the model shards in fp16 before sharing them in between ranks
    oss_config = FairscaleOSSConfig(
        broadcast_fp16=True
    )

    print("===> Building model")
    model = Net(upscale_factor=2)

    loss = feat_loss

    #optimizer = optim.AdamW(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=opt.weight_decay),
    optimizer = StokeOptimizer(
         optimizer = AdamW,
         optimizer_kwargs = {
             "lr" : opt.lr,
             "betas" : (0.9, 0.99),
             "eps" : 1e-8,
             "weight_decay" : opt.weight_decay
         }

    )

    # Build the object with the correct options/choices (notice how DistributedOptions and FP16Options are already provided
    # to make choices simple) and configurations (passed to configs as a list)
    stoke_model = Stoke(
        model=model,
        verbose=False,     # verbose just prints out stuff, throws an error somewhere so disabled it
        optimizer=optimizer,
        loss=loss,
        batch_size_per_device=opt.batchSize,
        gpu=True,
        fp16= None, #FP16Options.amp,
        distributed= "ddp", #DistributedOptions.ddp,
        fairscale_oss=True,
        fairscale_sddp=True,
#         fairscale_fsdp = True,
        grad_accum_steps=4,
        grad_clip=opt.grad_clip,
        configs=[amp_config, ddp_config, oss_config]
    )

    print("===> Loading datasets")

    input_path = opt.trainDir + "Faces256" +  "/"                    
    target_path = opt.trainDir +  "Faces512" + "/"                  

    full_dataset = CustomDataset(input_path, target_path)

    train_size = int(0.99 * len(full_dataset))
    test_size = len(full_dataset) - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

    train_sampler = DistributedSampler(
        dataset=train_dataset,
        num_replicas=stoke_model.world_size,
        rank=stoke_model.rank
    )

    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=stoke_model.world_size,
        rank=stoke_model.rank
    )

    threads = opt.threads

    train_dataloader = stoke_model.DataLoader(
        dataset=train_dataset,
        collate_fn=lambda batch: dataset.collate_fn(batch),
        sampler=train_sampler,
        persistent_workers = True,
        #multiprocessing_context = "forkserver",
        num_workers=threads
    )

    val_dataloader = stoke_model.DataLoader(
        dataset=val_dataset,
        collate_fn=lambda batch: dataset.collate_fn(batch),
        sampler=val_sampler,
        persistent_workers = True,
        #multiprocessing_context = "forkserver",
        num_workers=8
    )

Also, I want to highlight that setting the verbose parameter to True threw an attrs class error, which didn't seem significant, so I just turned it off. Another main thing is the distributed="ddp" setting: if I use distributed=DistributedOptions.ddp according to the docs, it throws an error stating that the DDP option is not enabled. This is the first time I am experimenting with Stoke, so apologies if some of these issues sound silly!

Thanks!

rushi-the-neural-arch commented 2 years ago

If you want the complete code snippet for reference, I have uploaded it here https://gist.github.com/rushi-the-neural-arch/bee47ba87e5ddabf0cb47def9bc0b013

python -m torch.distributed.launch Stoke-DDP.py --projectName "PyTorch-4K-2X" --batchSize 20 --nEpochs 2 --lr 1e-3 --threads 8

ncilfone commented 2 years ago

Can you try running this (per here) as your launcher:

python -m torch.distributed.launch --use_env Stoke-DDP.py --projectName "PyTorch-4K-2X" --batchSize 20 --nEpochs 2 --lr 1e-3 --threads 8

You might not be creating the local rank env variable unless you are parsing the arg per the PyTorch instructions. The --use_env flag is the easiest way for Stoke to handle the device rank.
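
For reference, a rough sketch of the two conventions for picking up the local rank with torch.distributed.launch (assuming the stock launcher behavior; variable names here are just for illustration):

import argparse
import os

# 1) Default torch.distributed.launch behavior: the launcher passes --local_rank
#    to each worker process, so the script has to parse it itself
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args, _ = parser.parse_known_args()

# 2) With --use_env (or torchrun): the launcher sets the LOCAL_RANK env variable
#    instead, which is what DDPConfig(local_rank=int(os.getenv('LOCAL_RANK'))) expects
local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))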

Also, I want to highlight that setting the verbose parameter True, threw an attrs class error which wasn't significant so I just turned it off and another main thing is the distributed= "ddp" setting, if I keep distributed= DistributedOptions.ddp

Can you open another issue with these errors? Trying to squash as many bugs to get to v1.0 that's stable. Helps to see all these error on different system setups/clusters...

rushi-the-neural-arch commented 2 years ago

Hi, I tried using --use_env but it didn't have any effect; the same error repeats. Also, it gives a FutureWarning stating that torch.distributed.launch is deprecated and will be removed in the future, and to use torchrun (note that --use_env is set by default in torchrun).

Here is the log, and the error is the same as previously mentioned

FutureWarning: The module torch.distributed.launch is deprecated and will be removed in future. Use torchrun. Note that --use_env is set by default in torchrun. If your script expects --local_rank argument to be set, please change it to read from os.environ['LOCAL_RANK'] instead. See https://pytorch.org/docs/stable/distributed.html#launch-utility for further instructions
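
For reference, the equivalent torchrun invocation for the command above (assuming the same script and arguments) would be something like:

torchrun Stoke-DDP.py --projectName "PyTorch-4K-2X" --batchSize 20 --nEpochs 2 --lr 1e-3 --threads 8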

Also, yes, I will create a new issue for the other bugs that I have encountered. Glad to help you out, thanks!

ncilfone commented 2 years ago

Trying to isolate the error... Can you try running with persistent_workers=False as I think this is where it's occurring in the underlying...

rushi-the-neural-arch commented 2 years ago

Actually, persistent_workers=False is set by default and it still threw the error; I set it to True just to see if it changed the error somehow, but it didn't. I set it back to False and the same error repeats. Tbh, I don't understand why it says that multiprocessing_context is an unexpected keyword argument. This argument is used heavily in the PyTorch code mentioned here, and as you said, we are just wrapping the PyTorch code in Stoke here.

ncilfone commented 2 years ago

Ok... Try the branch issue_23 and let me know what happens...

I've made modifications to how the DataLoader class is called and how the multiprocessing method is set.

rushi-the-neural-arch commented 2 years ago

Umm.. I installed the issue_23 branch via pip install git+https://github.com/fidelity/stoke.git@issue_23, which shows the stoke version as v0.2.0+6.gf22980b, so I guess the new branch is correctly installed, but it still throws the same error.

File "Stoke-DDP.py", line 268, in main num_workers=threads File "/home/...stoke/stoke.py", line 847, in DataLoader persistent_workers=persistent_workers, File "/home.../stoke/data.py", line 127, in __init__ persistent_workers=persistent_workers, TypeError: intercept_args() got an unexpected keyword argument 'multiprocessing_context'

Configuration

stoke_model = Stoke(
    model=model,
    verbose=True,
    optimizer=optimizer,
    loss=loss,
    batch_size_per_device=opt.batchSize,
    gpu=True,
    fp16= None, #FP16Options.amp,
    distributed=DistributedOptions.ddp.value,
    fairscale_oss=True,
    fairscale_sddp=True,
    grad_accum_steps=1,
    grad_clip=ClipGradNormConfig(max_norm = opt.grad_clip, norm_type=2.0),
    configs=[amp_config, ddp_config, oss_config]
)

train_dataloader = stoke_model.DataLoader(
    dataset=train_dataset,
    collate_fn=lambda batch: dataset.collate_fn(batch),
    sampler=train_sampler,
    num_workers=opt.threads
)

ncilfone commented 2 years ago

Dang... this is an odd one. I've completely removed the multiprocessing_context kwarg on that branch.... Can you try it again? I want to see if there is something hidden deeper in PyTorch that's being called on init that I currently can't track...

rushi-the-neural-arch commented 2 years ago

Umm.. now it throws the same unexpected keyword argument error, but for the next __init__ argument, generator:

TypeError: intercept_args() got an unexpected keyword argument 'generator'

/stoke/data.py", line 125, in __init__

ncilfone commented 2 years ago

OK. This is actually helpful... it seems like the __init__ from the super call isn't dealing with **kwargs correctly when it forwards on to this magical intercept_args function (which I can't seem to find anywhere).

Wondering if this is a torch version issue... I've been testing most things in the included Dockerfiles here, which use v1.8.1.

rushi-the-neural-arch commented 2 years ago

Ohh okay, glad we could actually find the cause. I am using torch version '1.10.0+cu102', so I guess this version mismatch would be the sole reason for the error. Is there any quick fix for this or any alternative suggestions? I can try downgrading the PyTorch version to 1.8.1 as a last option.

ncilfone commented 2 years ago

Still trying to replicate this error in our/another environment... Hang tight

ncilfone commented 2 years ago

@rushi-the-neural-arch Can you do me a favor and let me know if any of the CIFAR10 examples are able to run in your environment?

These are quite minimal examples that should isolate the error to Stoke/PyTorch and not some other dependency that might be doing something unknown.

rushi-the-neural-arch commented 2 years ago

Ya sure! I will let you know in a while!

rushi-the-neural-arch commented 2 years ago

@ncilfone I ran the CIFAR10 examples on 4 GPUs using DDP and they are running perfectly fine!

python -m torch.distributed.launch --nproc_per_node=4 --use_env stoke/examples/cifar10/train.py \
-c stoke/examples/cifar10/config/ddp-gpu.yaml

On a side note, just for your reference, I faced this issue while running the script - ReboundX - OSError: librebound.cpython-35m-x86_64-linux-gnu.so: cannot open shared object file: No such file or directory

The quick fix to that was to pip uninstall reboundX and then:

  • pip install rebound
  • pip install reboundX
  • pip install spock-config

After this, the script runs perfectly fine (I haven't modified any of the configs)

ncilfone commented 2 years ago

@ncilfone I ran the CIFAR10 examples on 4 GPUs using DDP and they are running perfectly fine!

python -m torch.distributed.launch --nproc_per_node=4 --use_env stoke/examples/cifar10/train.py \
-c stoke/examples/cifar10/config/ddp-gpu.yaml

@rushi-the-neural-arch Well that's a twist in the story of this bug!

Not completely sure, but... my guess is that there might be a dependency in your code that's causing issues with multiprocessing spawning new processes with the full set of keyword args... I've made a few commits on the issue_23 branch to use **kwargs instead of direct keyword references (shouldn't matter, but you never know), a la how PyTorch Lightning handles their DataLoader shims (with **kwargs only); a rough sketch of that pattern is below. See if that fixes things (as I can't seem to replicate this on any systems I have access to)...
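
A rough sketch of that **kwargs-only forwarding pattern (illustrative only, not the actual Stoke or Lightning code):

from torch.utils.data import DataLoader


class ShimDataLoader(DataLoader):
    # By forwarding *args/**kwargs untouched, the shim never pins itself to a
    # specific DataLoader.__init__ signature (patched or not)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)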

...and if not, I think maybe try stripping down your example code to the minimal parts (PyTorch and Stoke only) with no bells and whistles (wandb, etc.). If that works, then we can work upwards to see which deps might be clashing.

On a side note, just for your reference, I faced this issue while running the script - ReboundX - OSError: librebound.cpython-35m-x86_64-linux-gnu.so: cannot open shared object file: No such file or directory

The quick fix to that was to pip uninstall reboundX and then

  • pip install rebound
  • pip install reboundX
  • pip install spock-config

After this, the script runs perfectly fine ( I haven't modified any of the configs)

This is super weird... Nothing in Stoke or Spock has a dependency on Rebound/ReboundX. Are you in a venv or just working globally? Might be a residual dep somewhere if it's the latter...

rushi-the-neural-arch commented 2 years ago

Sure, let me strip out the W&B logging for now and check again; I will report back in some time. Also, yeah, this seems weird to me too, since even Spock doesn't use ReboundX. However, I clearly remember that the error was on the line import reboundx as rbx, which is what led me to the above GitHub issue. Nonetheless, if possible, I will try to reproduce that issue again.

ncilfone commented 2 years ago

Roping #27 into master so you won't need to deal with the branch and will have all the other bug fixes.

ncilfone commented 2 years ago

Didn't mean to close. Sorry!

rushi-the-neural-arch commented 2 years ago

Yeah, no worries! Sorry, I was on leave for a few days. I checked back again today, removed all the unnecessary stuff like W&B etc. from the script and re-ran it, but it's the same issue! :\ TypeError: intercept_args() got an unexpected keyword argument 'multiprocessing_context' (along with the log line Stoke -- Automatically handling moving model input data to GPU(s)...). I'm kind of frustrated at this point; it has been more than a week and I can't figure this out. This is just one simple bug, and I guess if it gets solved I can continue with the training (mostly) hassle-free.

rushi-the-neural-arch commented 2 years ago

@ncilfone I am searching for similar issues faced by other users and I found the following references. It seems that PyTorch Lightning users also faced the same issue (unexpected keyword argument: multiprocessing_context / generator) and they somehow got a fix for it. Let me know if this is helpful in finding a solution/hack! I will keep updating this comment whenever I find new relevant references.

rushi-the-neural-arch commented 2 years ago

@ncilfone I wrote the same sample script using FairScale (GitHub Gist) and I can successfully create the data loader and run the model for training (on the gloo backend), whereas in Stoke we are hitting the issue while creating the dataloader itself. I guess there's a very small bug or mistake that we are overlooking.

ncilfone commented 2 years ago

Hey @rushi-the-neural-arch. Sorry about the lengthy debug here :-/ I know it's frustrating... hopefully we can solve this ASAP and get you on your way. I was trying to break your code into pieces to see if I can isolate what's happening...

stoke_model = Stoke(
    model=model,
    verbose=True,
    optimizer=optimizer,
    loss=loss,
    batch_size_per_device=opt.batchSize,
    gpu=True,
    fp16= None, #FP16Options.amp,
    distributed=DistributedOptions.ddp.value,
    fairscale_oss=True,
    fairscale_sddp=True,
    grad_accum_steps=1,
    grad_clip=ClipGradNormConfig(max_norm = opt.grad_clip, norm_type=2.0),
    configs=[amp_config, ddp_config, oss_config]
)

train_dataloader = stoke_model.DataLoader(
    dataset=train_dataset,
    collate_fn=lambda batch: dataset.collate_fn(batch),
    sampler=train_sampler,
    num_workers=opt.threads
)

I think there is an error in this line: collate_fn=lambda batch: dataset.collate_fn(batch), perhaps? The dataset object doesn't exist (as far as I can see), so this lambda function is referencing the collate_fn on a non-existent dataset... normally that's reserved for when you need some custom map-style functions to assemble a batch. In the CIFAR10 examples, note that there is no collate_fn reference:

# Construct the DataLoader
train_loader = cifar_stoke.DataLoader(
    dataset=training_dataset,
    sampler=train_sampler,
    num_workers=configs.DataConfig.n_workers
    if configs.DataConfig.n_workers is not None
    else 0,
)

I think this is maybe a mis-documentation on my end as I was writing the docs based off a dummy example Dataset that had a collate_fn:

class RandomData(torch.utils.data.Dataset):
    def __init__(self, data_config: DataConfig):
        self._config = data_config
        # Make some random data (BxH): H has dim of features in
        self._x_data = torch.rand(self._config.data_len, self._config.n_features)
        self._y_data = torch.randint(0, self._config.n_class, (self._config.data_len,))

    def __len__(self):
        return self._x_data.shape[0]

    def __getitem__(self, idx):
        return self._x_data[idx, :], self._y_data[idx, ]

    @staticmethod
    def collate_fn(batch):
        batch_dict = {'x': [], 'y': []}
        for sample in batch:
            batch_dict['x'].append(sample[0])
            batch_dict['y'].append(sample[1])
        return torch.stack(batch_dict['x'], dim=0), torch.stack(batch_dict['y'])

Try removing the collate_fn ref and see what happens?

rushi-the-neural-arch commented 2 years ago

Umm, I re-ran the script again after removing the collate_fn but couldn't see any change. I have just kept the sampler and dataloader simple, like this:



train_sampler = (
        DistributedSampler(
            dataset=full_dataset,
            num_replicas=stoke_model.world_size,
            rank=stoke_model.rank
        )
    )

train_dataloader = stoke_model.DataLoader(
    dataset=full_dataset,
    sampler=train_sampler,
    num_workers=16,
)

rushi-the-neural-arch commented 2 years ago

On a high level, I have kept the script as simple as this

amp_config = AMPConfig(
    init_scale=2.**14
)

ddp_config = DDPConfig(
    local_rank= int(os.getenv('LOCAL_RANK')),
    convert_to_sync_batch_norm=True
)
oss_config = FairscaleOSSConfig(
    broadcast_fp16=True
)

print("===> Building model")
model = Net(upscale_factor=2)

loss = feat_loss

optimizer = StokeOptimizer(
     optimizer = AdamW,
     optimizer_kwargs = {
         "lr" : 1e-3,          
         "betas" : (0.9, 0.99),
         "eps" : 1e-8,
         "weight_decay" : 1e-4        
     }

 )

stoke_model = Stoke(
    model=model,
    verbose=True,    
    optimizer=optimizer,
    loss=loss,
    batch_size_per_device= 10,     
    gpu=True,   #configs.RunConfig.gpu,
    fp16=None, #configs.RunConfig.fp16,
    distributed=DistributedOptions.ddp.value, 
    fairscale_oss=True, #configs.RunConfig.oss,
    fairscale_sddp=True, #configs.RunConfig.sddp,
    configs= [amp_config, ddp_config, oss_config],     
    grad_accum_steps=1, #configs.RunConfig.grad_accum,
    grad_clip=ClipGradNormConfig(max_norm = 0.1, norm_type=2.0),
)

full_dataset = CustomDataset(input_path, target_path)

train_sampler = (
    DistributedSampler(
        dataset=full_dataset,
        num_replicas=stoke_model.world_size,
        rank=stoke_model.rank
    )
)

train_dataloader = stoke_model.DataLoader(
    dataset=full_dataset,
    sampler=train_sampler,
    num_workers=16,

)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start = 0.9, steps_per_epoch=len(train_dataloader), epochs=epochs)

for epoch in range(2): 
    train_loss = train(train_dataloader, stoke_model, scheduler, epoch)

And my CustomDataset class is also pretty straightforward, as shown below

import os
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset

ImageFile.LOAD_TRUNCATED_IMAGES = True

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image(path: Path):
    return path.suffix in IMG_EXTENSIONS

class CustomDataset(Dataset):
    def __init__(self, input_image_path, target_image_path):
        self.input_path = input_image_path
        self.target_path = target_image_path

    def __len__(self):
        input_images = os.listdir(self.input_path) 
        return len(input_images)

    def __getitem__(self, idx):

        input = [f for f in os.listdir(self.input_path) if is_image(Path(f))]

        input_image = cv2.imread(self.input_path+input[idx]).astype(np.float32)
        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

        target = [f for f in os.listdir(self.target_path) if is_image(Path(f))]

        target_image = cv2.imread(self.target_path+target[idx]).astype(np.float32)
        target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)

        # Transpose it into CxHxW PyTorch format
        input_image = np.transpose(input_image, (2, 0, 1)).astype(np.float64)
        target_image = np.transpose(target_image, (2, 0, 1)).astype(np.float64)

        input_image = input_image / 255.
        target_image = target_image / 255.

        return torch.FloatTensor(input_image), torch.FloatTensor(target_image)

ncilfone commented 2 years ago

ok... maybe a dumb question on my end, but what exactly is installed in this conda venv /anaconda/envs/py37_default/bin/python?

Updated: I only ask because the only place I can really find a reference to an intercept_args function related to PyTorch's DataLoader is in older versions of fastai (https://github.com/fastai/fastai1/blob/master/fastai/basic_data.py#L10), but I don't think that's in play here (it's weird that the CIFAR10 example ran fine but your code still won't...)

rushi-the-neural-arch commented 2 years ago

Ohh yes, I do have fastai 1.0.58 installed in my environment, which I sometimes use for experimentation purposes, but we are not using it anywhere in the script right now. Could it still have an effect, though?? My torch version is 1.10.0

ncilfone commented 2 years ago

I think that could be causing the issue... This line here is definitely being dynamically injected into core PyTorch. I'm not too sure how this will change imports etc. but it could be getting called somehow. Maybe try in a fresh venv?

In the old venv, one way to test this should be to look at the signature of the __init__ function for the DataLoader class you are importing. This will show if it's being dynamically overridden. Try something like this:

from torch.utils.data import DataLoader
import inspect

print(inspect.signature(DataLoader.__init__))

Note: intercept_args disappeared from fastai post v2.0 release which is why I think I can only find references to that function in issues from older versions (https://github.com/fastai/fastai/issues/2328). Maybe also try upgrading to fastai V2+ and see if that fixes it?

rushi-the-neural-arch commented 2 years ago

Ohh okay, so I was using fastai modules for one of the custom loss functions I am using, and it seems the issue might have been injected from there. With fastai imported, the function signature looks like this:

(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate at 0x7f99ca7f2170>, pin_memory=True, drop_last=False, timeout=0, worker_init_fn=None)

After removing all FastAI imports, the function signature looks like this -

(self, dataset: torch.utils.data.dataset.Dataset[+T_co], batch_size: Union[int, NoneType] = 1, shuffle: bool = False, sampler: Union[torch.utils.data.sampler.Sampler, NoneType] = None, batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence], NoneType] = None, num_workers: int = 0, collate_fn: Union[Callable[[List[~T]], Any], NoneType] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Union[Callable[[int], NoneType], NoneType] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: int = 2, persistent_workers: bool = False)

It worked!! @ncilfone Thank you so much!! This was the issue all along! I checked and removed everything except the fastai imports; it was such a silly mistake, but hard to find. I successfully ran the script for 2 epochs and it is working perfectly!
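
For anyone who hits this later: a quick programmatic check (assuming, as in my case, that the patched __init__ is missing the newer keyword arguments) could be something like:

import inspect
from torch.utils.data import DataLoader

# If fastai<2.0 (or anything else) has monkey-patched DataLoader.__init__,
# the newer keyword arguments will be missing from its signature
params = inspect.signature(DataLoader.__init__).parameters
if "multiprocessing_context" not in params:
    print("DataLoader.__init__ looks monkey-patched; check your imports")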

rushi-the-neural-arch commented 2 years ago

Also, @ncilfone, one small doubt: I couldn't find documentation for the correct way to initialise a scheduler in Stoke. Can you please let me know if the code snippet below is the best way to use schedulers?

stoke_optimizer= StokeOptimizer(
     optimizer = AdamW,
     optimizer_kwargs = {
         "lr" : 1e-3,          
         "betas" : (0.9, 0.99),
         "eps" : 1e-8,
         "weight_decay" : 1e-4        
     }

 )

stoke_model = Stoke(model, stoke_optimizer.......)

orig_optim = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)

scheduler = optim.lr_scheduler.OneCycleLR(orig_optim, max_lr=0.001, pct_start = 0.9, steps_per_epoch=len(train_dataloader), epochs=epochs)

train():
     .......
      stoke_model.step()
      scheduler.step()

We can't use the stoke_optimizer dictionary as the scheduler's optimizer argument, since it is incompatible and throws an error like this: TypeError: dict is not an Optimizer

However, I suspect that the above method I am using for the scheduler might not be totally correct and could lead to different results. If that's the case, please let me know what the correct way to do it would be!

Sample code posted here

ncilfone commented 2 years ago

Hi @rushi-the-neural-arch

Glad we've finally solved this!!!

As for the scheduler: right now Stoke doesn't help deal with the scheduler, but there is an open issue to add that (#20). I'm still undecided if that's something Stoke should support, seeing as it's not related to 'accelerators' and can currently be done outside of the API. Feel free to comment there with your opinions.

Note: The stoke object has a bunch of properties that expose all the underlying objects, including the optimizer. See here. For some odd reason the docs are not picking up these @property methods....

The current way to deal with the scheduler would be:

stoke_optimizer= StokeOptimizer(
     optimizer = AdamW,
     optimizer_kwargs = {
         "lr" : 1e-3,          
         "betas" : (0.9, 0.99),
         "eps" : 1e-8,
         "weight_decay" : 1e-4        
     }

 )

stoke_model = Stoke(model, stoke_optimizer.......)

scheduler = optim.lr_scheduler.OneCycleLR(stoke_model.optimizer, max_lr=0.001, pct_start = 0.9, steps_per_epoch=len(train_dataloader), epochs=epochs)

train():
     .......
      ### PyTorch 1.10 -- they changed the order required
      stoke_model.step()
      scheduler.step() 

     ### PyTorch < 1.10
     ......
     scheduler.step()
     stoke_model.step()

ncilfone commented 2 years ago

Closing since we've finally resolved this!

rushi-the-neural-arch commented 2 years ago

Hi! Sure, yeah, thank you very much for all the help! I will experiment more with Stoke; it's pretty helpful for distributed training!