The split branch does not support parallel training: when the branch is instantiated, some unneeded submodules are instantiated as well, and because they never participate in the loss computation, PyTorch raises an exception (DistributedDataParallel complains about parameters that did not receive gradients).
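The unused-parameter problem can be reproduced with a minimal sketch (the `SplitModel` module and its branch names below are hypothetical, just to illustrate the mechanism):

```python
import torch
import torch.nn as nn

# Hypothetical two-branch model: only branch_a contributes to the loss.
class SplitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(8, 1)
        self.branch_b = nn.Linear(8, 1)  # instantiated but never used

    def forward(self, x):
        return self.branch_a(x)  # branch_b does not run

model = SplitModel()
loss = model(torch.randn(4, 8)).sum()
loss.backward()

# branch_b received no gradients. Under DistributedDataParallel this is
# exactly what triggers the "parameters that were not used in producing
# loss" error, unless the wrapper is created with
#   nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
assert model.branch_a.weight.grad is not None
assert model.branch_b.weight.grad is None
```

Passing `find_unused_parameters=True` to `DistributedDataParallel` may be worth trying before refactoring, though it adds some overhead per iteration.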
If you need parallel training, you will have to refactor the code. That may not be painful, but it will take a lot of time.
So a good approach is to train the two branches separately, then save the split results during inference. According to my experiments, dual-branch optimization barely affects single-task results. Good luck!