logicalclocks / maggy

Distribution transparent Machine Learning experiments on Apache Spark
https://maggy.ai
Apache License 2.0

Add Petastorm/Parquet support, refactor Server #91

Closed · amacati closed this 3 years ago

amacati commented 3 years ago

Adds support for Petastorm/Parquet files as datasets and refactors the Server code.

Dataloader change (patching.py): DataLoaders (including Petastorm's) are wrapped so that the batches they return are recursively traversed and any torch.Tensors they contain are moved to the GPU (see the _to_cuda discussion below).

Server change (rpc.py): The server previously contained the message processing logic for all experiment types. The refactor increases the modularity of the RPC communication, making it easy to add future experiment types without adding further complexity to the base server.
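
For illustration, a minimal sketch of the kind of per-experiment message handling such a modular server enables; the class and method names are hypothetical and not taken from rpc.py:

    class MessageHandler:
        # Base class: one subclass per experiment type owns its message logic.
        def handle(self, msg):
            raise NotImplementedError

    class OptimizationHandler(MessageHandler):
        # Hypothetical handler for a hyperparameter optimization experiment.
        def handle(self, msg):
            # Experiment-specific processing lives here instead of in the base server.
            return {"type": "OK", "payload": msg.get("payload")}

    class Server:
        # The base server only routes messages to the registered handler.
        def __init__(self, handler):
            self.handler = handler

        def process_message(self, msg):
            return self.handler.handle(msg)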

moritzmeister commented 3 years ago

I have to admit I don't 100% follow the Petastorm DataLoader: why do you need to traverse the data?

amacati commented 3 years ago

Are you talking about the _to_cuda(data) function? Petastorm, for example, returns the loaded Parquet data as dicts of torch.Tensors, and these tensors need to be transferred to the GPU in order to work with DDP. The same goes for users' custom DataLoader classes: they may return data as dicts, lists, or even nested lists. In those cases the wrapper recurses into the lists/dicts and moves the underlying torch.Tensors to the GPU.

An example: say you have a custom DataLoader that returns the x and y values of some function as inputs and the z values as targets. The DataLoader does not contain any .to(device) calls because it was written in a local notebook and executed on the CPU. This code still works in Maggy because of the _to_cuda call:

for data in train_loader:
    # data is already a dict of GPU tensors thanks to Maggy's _to_cuda wrapper
    x, y, z = data['x'], data['y'], data['z']
    inputs = torch.stack((x, y)).T   # shape (batch, 2)
    outputs = model(inputs)          # model is on the GPU
    loss = criterion(outputs, z.unsqueeze(-1))

But it's not really traversing the whole dataset, only the structure of the output from the DataLoader, one batch at a time. There should be no associated performance hit, since you'd have to issue the .cuda() calls anyway to get DDP to work.
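
For reference, a minimal sketch of what such a recursive move-to-GPU helper could look like; this is illustrative only, and the actual implementation in patching.py may differ:

    import torch

    def _to_cuda(data):
        # Recursively walk the structure returned by the DataLoader and move
        # every torch.Tensor it contains to the GPU. Non-tensor leaves are
        # returned unchanged.
        if isinstance(data, torch.Tensor):
            return data.cuda()
        if isinstance(data, dict):
            return {key: _to_cuda(value) for key, value in data.items()}
        if isinstance(data, (list, tuple)):
            return type(data)(_to_cuda(item) for item in data)
        return data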