There is a fundamental difference between the PyTorch/XLA and TF/TPU paradigms. PT/TPU builds all the graphs, initializes the weights, runs the input pipelines, etc. on the host VM and then feeds the TPUs, whereas TF/TPU builds the TF graphs, converts them into XLA graphs, and hands them over to the TPU to do all the heavy lifting.
Also, based on the Kaggle Kernel you posted, I assume the SIGKILL was issued due to VM RAM OOM (not memory on the TPU cores), though we'd need to check the kernel logs to know for sure. @ifigotin may be working on bumping that limit but I'll let him chime in on that status.
@jysohn23 Thanks for your reply. So does this difference in paradigm lead to some models not working, either due to OOM problems or other problems? If so, why?
I would also note that Kaggle did give me a separate message saying:
Your notebook tried to allocate more memory than is available.
So you are probably right, I realize now it probably is VM RAM OOM. How can I reduce the memory usage in this notebook?
Can you try this?
!free -h
And this:
!cat /proc/cpuinfo | grep processor | wc -l
@tmabraham Yeah, PT/TPU sometimes uses more RAM on the GCE VM, whereas TF/TPU uses it on the TPU host. But as long as you can get a GCE VM with more RAM you should be fine.
@dlibenzi There is a total of 18 GB of memory:
total used free shared buff/cache available
Mem: 18G 866M 12G 1.0M 5.5G 17G
Swap: 0B 0B 0B
The output of the second command is 4
@jysohn23 Are there any steps I can take to reduce VM RAM usage in my notebook?
How big are the x_train and x_valid tensors?
@dlibenzi x_train is (435775,128) and y_train is (435775,1). Note the TF TPU kernel uses (435775,192).
That gets encoded AFAICT. Can you print the final x_train shape?
@dlibenzi I didn't understand. This is the tokenized/encoded shape. There are 435775 sentences and, when encoded, the representation has max_len equal to 128. This is the shape of the data that I use when creating my TensorDataset.
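For reference, this is roughly how the TensorDataset gets built from it (a minimal sketch with placeholder arrays; the real shapes are the ones above):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder arrays standing in for the tokenized data above; in the actual
# kernel x_train is (435775, 128) token ids and y_train is (435775, 1) labels.
x_train = np.zeros((1024, 128), dtype=np.int64)
y_train = np.zeros((1024, 1), dtype=np.float32)

train_ds = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, drop_last=True)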
Just wanted to follow up on this... Is there any way this can be fixed? Or is this a limitation of PyTorch XLA?
Also, I will note that I get similar OOM problems when using bert-base-cased, which is supposed to be a model of the same size. I haven't investigated this thoroughly, so I don't know if it's the same issue or not.
@jysohn23 @dlibenzi Sorry to keep bothering you. Just wanted to see if this is a known issue or if there is something I am supposed to do in my environment to prevent this issue.
Hi @tmabraham, sorry for the late reply. As far as I can tell, it's a limitation of the VM that Kaggle provides. Can you try running the exact same thing in a Colab notebook? Make a copy of this Colab notebook sample we provide: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb and paste in your content.
@jysohn23 Yes I was able to get something that seems to work in Colab.
@tmabraham Yeah, PT/TPU sometimes uses more RAM on the GCE VM, whereas TF/TPU uses it on the TPU host. But as long as you can get a GCE VM with more RAM you should be fine.
Are there steps we can take to reduce PT/TPU RAM usage, or is this an inherent limitation of PyTorch XLA?
So I ran your Kaggle notebook, and after tokenization, there are about 3GB left.
There are 11GB buffer cache, but the dataset seems pretty big.
To be clear, it's not a limitation of pytorch/xla, but rather an imbalance in the resources that are given out for free. On Kaggle they grant a couple of CPU cores and a few GB of RAM to feed 8 TPU cores. You'd have the exact same problem if you were given 4 free V100s on Kaggle kernels with only a couple of CPU cores and a few GB of RAM. You can try creating the model before forking the processes, as long as the weights stay read-only, to reduce the memory footprint caused by the model weights.
@jysohn23 I still think this could be a limitation of PyTorch XLA because there is a TF kernel that works in Kaggle Kernels. Maybe there are some optimizations in TF that are not possible in PyTorch XLA?
I am creating the model before the forking processes.
PT/TPU vs TF/TPU is not an apples-to-apples comparison as they have different paradigms: https://github.com/pytorch/xla/issues/1870#issuecomment-608671471
@jysohn23 I understand that, but I guess I was hoping there is still something that could be done to prevent the higher usage of RAM with PyTorch XLA vs TF TPU. I guess the answer is no.
I will try to ask Kaggle if they can potentially increase the RAM for the VMs. If not, I will train in Colab. Thanks for the clarification!
I ran your Kaggle kernel trying to add stuff like del df_train, df_valid, which freed up about 500 MB, but it still looks like it runs out of memory right as we do model = mx.to(xm.xla_device()) in the xmp context, since there we create 8 copies of the model after sending it to device 😞.
We have made a change (which should be on nightly) that lowers the host memory utilization. @tmabraham mind giving it a try on nightly?
I tried myself with nightly and it trains:
https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch
The trick is adding --version nightly to the env-setup script.
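So the setup cell looks roughly like this (the script path is the standard pytorch/xla env-setup one; double-check it against the getting-started notebook in case it has moved):

!curl -s https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly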
In that kernel (https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch), nprocs is set to 1 instead of 8.
I'm running into essentially the same issue in my own Kaggle kernel. Bert Base loads fine without OOM issues, however XLM-RoBERTa-Base does not. I don't even dare try XLM-RoBERTa-Large, which is able to load on TensorFlow as can be seen in this kernel: https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta
Have tried with 8 as well. With nightly, it trains. Rerunning now ...
So, we normally recommend this kind of structure:
def _mp_fn(...):
    model = Net().to(xla_device)
    ...

xmp.spawn(_mp_fn, ..., start_method='fork')
But if the VM is RAM-starved and the model has a hefty parameter size, the recommended setup is more like:
# Create once at global scope to share pages with child processes.
model = Net()

def _mp_fn(...):
    model.to(xla_device)
    ...

xmp.spawn(_mp_fn, ..., start_method='fork')
That, together with fork(2), makes sure there will be only one copy of the model's parameters in PyTorch CPU host memory.
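Putting that second layout together, here is a minimal end-to-end sketch (SimpleNet, the dummy batch and the hyper-parameters are stand-ins, not the thread's actual model):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 2)

    def forward(self, x):
        return self.fc(x)

# Created once at global scope; with start_method='fork' the children share
# these parameter pages instead of each holding a private copy in host RAM.
MODEL = SimpleNet()

def _mp_fn(index):
    device = xm.xla_device()
    model = MODEL.to(device)  # only the transfer to the TPU device is per-process
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # One dummy step, just to show where the training loop goes.
    data = torch.randn(8, 128, device=device)
    target = torch.zeros(8, dtype=torch.long, device=device)
    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')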
Oh that's awesome. @dlibenzi can you commit the updated Kaggle Kernel? I don't see --version nightly on the notebook when I click your link.
Yeah, sorry. I had forgotten to save 😄
https://www.kaggle.com/davidelibenzi/simple-xlmr-tpu-pytorch
I am running the code, and it roughly works 1/3 of the time. There is something weird going on...
I tried XLM-RoBERTa-Large based on @dlibenzi's code: with nprocs=8 it failed again, but with nprocs=1 it trains.
I believe the issue comes from the following line:
xm.save(model.state_dict(), "xlm_roberta_model.bin")
Is there a way to move the model saving outside the _mp_fn function?
In my experiments, the model saving surely doesn't break when we use 8 cores, even when it's inside _mp_fn; I assume it depends on the size of the model, since model.state_dict seems to load the dict into the CPU memory of each of the 8 cores.
As for xm.save(), we have already optimized the code to fetch TPU tensors to CPU only for the core which is actually going to do the save, but my guess is that at that point the memory is really tight after the 8 processes ran the training.
You can try this to see if it helps:
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
As for xm.save(), we have already optimized the code to fetch TPU tensors to CPU only for the core which is actually going to do the save, but my guess is that at that point the memory is really tight after the 8 processes ran the training. You can try this to see if it helps: os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
Thanks, will try that. There is no way to move it out to the main vm, right?
No, at that point every _mp_fn() runs in a different process, so moving it out would not change anything.
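For reference, a sketch of what the save step inside _mp_fn can look like, together with the allocator limit suggested above (not the kernel's exact code; the checkpoint file name is just the one from the thread):

import os

# Suggested above: cap the CPU-side tensor allocator to ease host RAM pressure.
# Setting it before torch_xla is imported is the safe way to make sure the
# runtime sees it.
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import torch_xla.core.xla_model as xm

def save_checkpoint(model, path='xlm_roberta_model.bin'):
    # xm.save fetches the tensors to CPU and writes the file from the master
    # ordinal only, so calling it from every replica inside _mp_fn is safe.
    xm.save(model.state_dict(), path)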
As a side note, David, why would VM RAM keep increasing gradually during the training loop? The VM has sufficient RAM and I am not saving any tensors or anything like that (ref: the Colab notebooks).
Also, can a SIGKILL happen because of TPU OOM? I don't see the VM RAM spike beyond 9-10 GB during the training loop, so I'm not sure why there's a SIGKILL all of a sudden in the middle of it.
Thanks!
I did some changes like creating properly sized buckets and creating the tensors at global scope. Now training goes pretty fast and without OOM:
https://colab.research.google.com/drive/1pZRVafm_3wu1AKfU_W9S55nX7I-sch9q
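To expand on the buckets a bit: they pad each sequence up to one of a few fixed lengths, so XLA compiles only a handful of shapes instead of one per sequence length. A rough sketch (bucket sizes and pad id are just examples, not the Colab's exact values):

import numpy as np

BUCKETS = [16, 32, 64, 128]  # example bucket lengths

def bucket_pad(token_ids, pad_id=1):
    # Pad (or truncate) a list of token ids up to the nearest bucket length.
    target = next((b for b in BUCKETS if b >= len(token_ids)), BUCKETS[-1])
    ids = token_ids[:target]
    return np.array(ids + [pad_id] * (target - len(ids)), dtype=np.int64)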
Thanks @dlibenzi - did you also try it with roberta-large? With base model there are usually no issues, but large gives me a lot of trouble, also in terms of model not training (loss not decreasing).
@psinger I have xlm-roberta-large working in Kaggle Kernels and will release code soon.
Thanks @dlibenzi - did you also try it with roberta-large? With base model there are usually no issues, but large gives me a lot of trouble, also in terms of model not training (loss not decreasing).
I have changed the Colab above to use large, and it seems to be working. Though I have no idea what the loss curve should be, to call it good.
With BF16, it is OK. Using F32 leads to host OOM with roBERTa Large.
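For anyone reproducing this, BF16 is turned on through the XLA_USE_BF16 environment variable, set before torch_xla is imported:

import os

# Map float32 tensors to bfloat16 on the TPU; reduces tensor memory at some
# precision cost. Set before importing torch_xla so the runtime picks it up.
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm  # imported after the flag is set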
@dlibenzi Yes I have observed the same.
This loads and trains (pretty fast) roBERTa Large.
Two major changes.
First, I create a file dataset composed of PyTorch-serialized records, one time. This allows the training process to start cleaner, memory-wise, as it does not need to load datasets, tokenizers, etc...
Second, I serialize the model.to(device) call.
This seems to be working for me:
https://colab.research.google.com/drive/1k_fPaom7CymYMeIyfShP__zd0rDLLGGC
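Roughly, the two changes look something like this sketch (not the Colab's exact code; the file name and rendezvous tags are illustrative):

import torch
import torch_xla.core.xla_model as xm

def build_records(x_train, y_train, path='train_records.pt'):
    # Change 1: tokenize once and dump ready-made tensors to disk, so the
    # training processes load records instead of re-running the tokenizer.
    torch.save({'x': torch.as_tensor(x_train), 'y': torch.as_tensor(y_train)}, path)

def to_device_serialized(model, device):
    # Change 2: move the model to the TPU one process at a time, so peak host
    # RAM holds one extra copy of the weights instead of eight at once.
    for ordinal in range(xm.xrt_world_size()):
        if xm.get_ordinal() == ordinal:
            model.to(device)
        xm.rendezvous('to_device_{}'.format(ordinal))
    return model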
Never mind ... there were two places mentioning base/large, and I had changed only the tokenizer one. With this Colab it trains with BF16, but TPU OOM with F32:
https://colab.research.google.com/drive/1Yp9QGV7lWTSdGkTcnRreRVIn05_pDJyl
Will ping back about the status, folks! It's really been fun playing around with TPUs, and a nice test of patience and cool ideas as well 😅! Thanks again for this thread!
--------- EDIT 1
I can run the above Colab successfully. I have updated the LR, David, and the loss looks great to me after that! Just waiting for the eval results to appear.
First, I create a file dataset composed of PyTorch-serialized records, one time.
The way we got it working on Kernels was by using a cached numpy pre-tokenized dataset loaded as a TensorDataset, and it worked like a charm on both Colab and Kaggle! (Effective batch size was 128.)
The initial attempt was to do tokenization etc. on the fly at batch level, but that failed for some reason :( (possibly because I used batches of all lengths, whereas I see you pad them to the nearest multiples of 2!).
Can you point me, @dlibenzi, to where we can read more about the above serialized records to understand them better?
Thanks!
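For completeness, the caching we used looks roughly like this (the cache file name and tokenize_fn are placeholders; tokenize_fn is assumed to return fixed-length, e.g. 128-token, id lists):

import os
import numpy as np
import torch
from torch.utils.data import TensorDataset

CACHE = 'train_tokens.npy'  # placeholder cache file name

def load_or_tokenize(texts, tokenize_fn):
    # Tokenize once and cache to disk; later runs reuse the cached array and
    # skip the tokenization cost (and its RAM spike) entirely.
    if os.path.exists(CACHE):
        return np.load(CACHE)
    ids = np.array([tokenize_fn(t) for t in texts], dtype=np.int64)
    np.save(CACHE, ids)
    return ids

# Usage (train_texts and labels are placeholders for the actual arrays):
# ids = load_or_tokenize(train_texts, tokenize_fn)
# dataset = TensorDataset(torch.from_numpy(ids), torch.from_numpy(labels))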
I am trying to train an XLM-R model in Kaggle Kernels with TPU enabled. There was a TF kernel that was able to do this successfully: https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta
However, attempts to train a similar model with PyTorch XLA have not been successful due to OOM errors. I tried to keep the code as similar as possible and made sure all non-XLA variables (dataset, model, etc.) were defined globally so they weren't replicated 8 times. I am actually using a smaller version of the model (base vs. large) and much lower batch sizes. I even tried the multi-threading interface (which is apparently now deprecated), as I read that multi-threading uses less memory. In all cases I get OOM errors. In most cases, it will load the model and run the forward function, but fail when calculating the loss. In some cases, it fails at loss.backward().
I have two questions related to this: