Open fladventurerob opened 1 year ago
torch.compile does not emit a model in the way you expect it to, and I think maybe some new documentation has led you astray.
If you intend to pickle the exported model, give export a try. See the export section of https://pytorch.org/get-started/pytorch-2.0/
it sounds like @fladventurerob is trying to compile the model after loading it, not before exporting it.
however, @fladventurerob it would be helpful if you can provide a runnable script for us to look at, rather than just a description.
it might be because of multiprocess compile. @fladventurerob can you give this a try: after import torch, add these two lines before the torch.compile call:
from torch._inductor import config
config.compile_threads = 1
Does the error reproduce?
it sounds like @fladventurerob is trying to compile the model after loading it, not before exporting it.
however, @fladventurerob it would be helpful if you can provide a runnable script for us to look at, rather than just a description.
You are correct. The model exists already. I was loading it into a forward test script. Based upon the documentation I was assuming this line was needed to use an existing model, rather than to export the model.
@fladventurerob any update on whether compile_threads=1 helps, or are you able to provide a repro script for us?
I had the same issue -- setting compile_threads=1 fixed it for me!
For what it's worth, I'm using triton at HEAD and this isn't an issue I run into any more (I no longer need to specify compile_threads=1) as of the build I made today.
I haven't followed the changes made very closely recently, but I did see this merged today (though perhaps unrelated): https://github.com/openai/triton/pull/1133.
I would not close this issue. It appears randomly when you torch.save a model that took a long time to train, which means that nothing is saved. Could you update the __getstate__ of that model internally, so that multithreading-related pickling issues won't occur? I consider this a bug in torch.compile().
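The __getstate__ idea suggested above can be sketched in plain Python. This is only an illustration of the pattern, not how torch.compile is implemented: CompiledWrapper and make_callback are hypothetical names standing in for a compiled module and the unpicklable callback it holds.

```python
import pickle


def make_callback():
    # Hypothetical stand-in for an unpicklable compiled callback:
    # a function defined inside another function.
    def _callback(x):
        return x * 2
    return _callback


class CompiledWrapper:
    """Hypothetical wrapper that holds picklable state plus a local callback."""

    def __init__(self, weights):
        self.weights = weights
        self._callback = make_callback()

    def __getstate__(self):
        # Drop the unpicklable callback before pickling.
        state = self.__dict__.copy()
        state.pop("_callback", None)
        return state

    def __setstate__(self, state):
        # Restore the picklable state and rebuild the callback on load.
        self.__dict__.update(state)
        self._callback = make_callback()

    def __call__(self, x):
        return self._callback(x) + sum(self.weights)


restored = pickle.loads(pickle.dumps(CompiledWrapper([1, 2, 3])))
print(restored(5))
```

The trade-off in this sketch is that the callback is rebuilt on load, so any cached compilation state would be lost and recreated.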
I believe this one will be fixed by https://github.com/pytorch/pytorch/pull/101651 when it lands
I met the same bug, here is the simplest test case I can offer:
import torch
from torch import nn


class CNNModel(nn.Module):
    def __init__(
        self,
        class_number=10,
        input_channel=3,
        dropout=0.1,
        kernel_sizes=[5, 3, 3],
        paddings=[2, 1, 1],
        hidden_dims=[32, 32, 32]
    ):
        super(CNNModel, self).__init__()
        self.layers = []
        self.layers.append(nn.Conv2d(input_channel, hidden_dims[0], kernel_size=kernel_sizes[0], padding=paddings[0]))
        self.layers.append(nn.ReLU())
        for i in range(len(hidden_dims) - 1):
            self.layers.append(
                nn.Conv2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=kernel_sizes[i + 1], padding=paddings[i + 1])
            )
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Dropout(p=dropout))
        self.layers = nn.Sequential(*self.layers)
        self.glp = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(hidden_dims[-1], class_number))

    def forward(self, x):
        x = self.layers(x)
        x = self.glp(x)
        x = self.fc(x)
        return x


if __name__ == '__main__':
    import pickle
    model = CNNModel()
    from torch._inductor import config
    config.compile_threads = 1
    model = torch.compile(model)
    pickle.dumps(model)
The error is:
Traceback (most recent call last):
  File "test", line 46, in <module>
    pickle.dumps(model)
AttributeError: Can't pickle local object 'convert_frame.<locals>._convert_frame'
Maybe this information helps :)
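For readers unfamiliar with this error, here is a torch-free sketch of the mechanism. pickle serializes functions by reference (module plus qualified name), so a function defined inside another function cannot be looked up again, and any object that holds such a function fails too. The convert_frame and Wrapper names below are stand-ins chosen to mirror the traceback, not the real torch._dynamo code.

```python
import pickle


def convert_frame():
    # Stand-in factory: returns a function defined in its local scope.
    def _convert_frame():
        return "compiled"
    return _convert_frame


class Wrapper:
    """Hypothetical object that stores the local function as an attribute."""

    def __init__(self, callback):
        self.callback = callback


def try_pickle(obj):
    # Return "ok" on success, otherwise a description of the error.
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as exc:
        return f"{type(exc).__name__}: {exc}"


local_fn = convert_frame()
print(local_fn.__qualname__)          # contains "<locals>"
print(try_pickle(local_fn))           # fails: local object
print(try_pickle(Wrapper(local_fn)))  # fails too: the attribute is pickled
```

The exact exception type and message vary between Python versions, which is why the sketch catches both AttributeError and pickle.PicklingError.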
I'm having a similar issue, but with dataloaders with num_workers > 0.
(In particular, when using spawn instead of fork. I think this is probably because it copies the entire environment.)
https://github.com/pytorch/pytorch/issues/101107 seems relevant.
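A rough way to check in advance whether an object would survive being sent to a spawned worker is to serialize it with the same pickler that multiprocessing uses (ForkingPickler, which also appears in the traceback in the issue description). This is a stdlib-only sketch. make_transform and survives_spawn are hypothetical helper names, and the local function merely imitates a compiled callback.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


def make_transform():
    # Local function, imitating an unpicklable compiled callback.
    def _transform(x):
        return x + 1
    return _transform


def plain_transform(x):
    # Module-level function: picklable by reference.
    return x + 1


def survives_spawn(obj):
    """Return True if obj can be serialized the way a spawned worker needs."""
    try:
        ForkingPickler.dumps(obj)
        return True
    except (AttributeError, TypeError, pickle.PicklingError):
        return False


print(survives_spawn(plain_transform))
print(survives_spawn(make_transform()))
```

With fork, the child inherits the parent's memory and nothing is pickled at start-up, which would explain why the error only shows up under spawn.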
config.compile_threads = 1 doesn't fix the error for me with 2.1.1.
Hello! I encounter this issue while compiling a transformation used within the multi-processing context of the data loader.
Even though I intend for the transformation to run on the CPU, the GPU is detected leading to the following error:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
I tried to investigate the issue by commenting out the calls to torch.cuda.get_rng_state() and torch.cuda.set_rng_state() in torch/_dynamo/convert_frame.py::wrap_convert_context and torch/_dynamo/utils.py::preserve_rng_state, but the problem persists. I suspect that these two functions are called elsewhere (possibly in _inductor or in the compiled code itself). I'm happy to turn this into a dedicated issue.
In an effort to resolve this error, I followed the recommendation and used spawn as the multiprocessing context for the DataLoader. However, this led to the error mentioned in this issue:
AttributeError: Can't pickle local object 'convert_frame.<locals>._convert_frame'
While preparing a pull request, I attempted to move the definition of _convert_frame() outside of convert_frame(). This introduced a new challenge, because two attributes are added to the function, which changes the object and breaks the serialization:
_pickle.PicklingError: Can't pickle <function _convert_frame at 0x7f64e40a5ea0>:
it's not the same object as torch._dynamo.convert_frame._convert_frame
I'm unassigning myself from this because I couldn't justify spending more time on it. I got stuck getting the 3 final tests to pass here: https://github.com/pytorch/pytorch/pull/101651
The core idea is just to wrap some functions into classes so they become picklable
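A minimal sketch of that idea, assuming nothing about the actual PR: instead of a closure that captures its state, use a module-level class whose instances store the state as attributes and implement __call__. ConvertFrame and backend_name are illustrative names only.

```python
import pickle


class ConvertFrame:
    """Module-level callable that stores what a closure would have captured."""

    def __init__(self, backend_name):
        self.backend_name = backend_name

    def __call__(self, frame):
        return f"{self.backend_name}:{frame}"


callback = ConvertFrame("inductor")

# The instance pickles because its class is importable by name and its
# state is an ordinary attribute dictionary.
restored = pickle.loads(pickle.dumps(callback))
print(restored("frame0"))
```

This only works if every attribute stored on the instance is itself picklable.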
@anijain2305 to provide update
Possible partial fix:
AttributeError: Can't pickle local object 'TrainAugmentation.init.
@anthai0908 https://github.com/qfgaohao/pytorch-ssd/issues/71 seems more relevant to you. I think your comment is really long and unrelated. It might be nice to remove it so as not to clutter this issue.
Thanks @ringohoffman. After training and exporting to ONNX, I have one question: is it possible to deploy inference on Python 3.6.9 with the ONNX format?
@anijain2305 still looking to work on this soon.
Update from triage meeting: @williamwen42 thought this might have been fixed in another PR??
Can't repro what @kxzxvbk provided above, on 9c2c61d2dd.
I am unable to repro the issue here. I will keep the issue open for a month in case someone comes up with a better repro.
🐛 Describe the bug
When adding the line:
model = torch.compile(model)
after loading the model, this error occurs. When removing the line, the script functions as intended.
Error logs
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/opt/anaconda3/envs/ml1/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'convert_frame.<locals>._convert_frame'
Minified repro
No response
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @Xia-Weiwen @ipiszy @soumith @ngimel