Hi! The neural network modules are all in spark-sched-sim/schedulers/neural, including decima.py. I use a base class NeuralScheduler for these schedulers, which is defined in neural.py in that same directory. Another important file is spark-sched-sim/wrappers/neural, which converts the environment's observations into the models' formats, and converts their outputs into actions for the environment. All the code related to training these models (e.g. the PPO algorithm) is in the trainers directory.
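As a rough sketch of how that split can look (everything here apart from the name NeuralScheduler is an assumption for illustration, not the repository's actual API), the base class can own the model's forward/scheduling interface while a Gymnasium observation wrapper handles format conversion; only the observation direction is shown:

```python
# Illustrative sketch only: apart from the name NeuralScheduler, the classes,
# methods, and observation keys below are assumptions, not spark-sched-sim's API.
import gymnasium as gym
import torch
import torch.nn as nn


class NeuralScheduler(nn.Module):
    """Base class; subclasses such as the Decima scheduler supply the networks."""

    def forward(self, obs):
        raise NotImplementedError   # should return logits over schedulable actions

    def schedule(self, obs):
        with torch.no_grad():
            logits = self(obs)
        return int(logits.argmax())


class GraphObsWrapper(gym.ObservationWrapper):
    """Hypothetical wrapper: converts the env's raw observation into tensors
    for the model (the output-to-action direction is omitted here)."""

    def observation(self, obs):
        node_feats = torch.tensor(obs["node_features"], dtype=torch.float32)
        edge_index = torch.tensor(obs["edges"], dtype=torch.long).t()
        return {"node_features": node_feats, "edge_index": edge_index}
```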
Thank you for your answer. I'm focusing on how the parameters of Decima's [NodeEncoder, DagEncoder, GlobalEncoder] are trained. I tried running your default code. My understanding is: [NodeEncoder, DagEncoder, GlobalEncoder] are aggregated in ppo, and then during the ppo training process, the parameters of [NodeEncoder, DagEncoder, GlobalEncoder] will also be updated. At the same time, PPO and [NodeEncoder, DagEncoder, GlobalEncoder] are updated based on the same loss. Is my understanding correct? Thank you for providing such excellent code.
[NodeEncoder, DagEncoder, GlobalEncoder] are aggregated in ppo, and then during the ppo training process, the parameters of [NodeEncoder, DagEncoder, GlobalEncoder] will also be updated
I'm not sure what you mean by "aggregated in ppo", but you're right that all of these networks' parameters are updated during PPO. Note that they are all linked together: DagEncoder depends on NodeEncoder, and GlobalEncoder depends directly on both NodeEncoder and DagEncoder. The policy network - whose outputs are used to sample an action - depends on all three of them.
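For intuition, here is a minimal sketch of that dependency chain; the dimensions, mean-pooling, and signatures are invented for illustration and do not match the real modules in decima.py:

```python
# Sketch of the dependency chain described above; dimensions, pooling, and
# signatures are made up for illustration and do not match decima.py.
import torch
import torch.nn as nn


class NodeEncoder(nn.Module):
    def __init__(self, in_dim=5, emb_dim=16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.ReLU())

    def forward(self, node_feats):              # [num_nodes, in_dim]
        return self.mlp(node_feats)             # per-node embeddings


class DagEncoder(nn.Module):
    def __init__(self, emb_dim=16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.ReLU())

    def forward(self, node_embs, dag_index):    # consumes NodeEncoder's output
        num_dags = int(dag_index.max()) + 1
        # Toy mean-pool of each DAG's node embeddings into one DAG embedding.
        dag_embs = torch.stack(
            [node_embs[dag_index == d].mean(dim=0) for d in range(num_dags)])
        return self.mlp(dag_embs)


class GlobalEncoder(nn.Module):
    def __init__(self, emb_dim=16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.ReLU())

    def forward(self, dag_embs):                # consumes DagEncoder's output
        return self.mlp(dag_embs.mean(dim=0, keepdim=True))


class PolicyNet(nn.Module):
    """Scores each node using node-, DAG-, and global-level embeddings."""

    def __init__(self, in_dim=5, emb_dim=16):
        super().__init__()
        self.node_enc = NodeEncoder(in_dim, emb_dim)
        self.dag_enc = DagEncoder(emb_dim)
        self.glob_enc = GlobalEncoder(emb_dim)
        self.head = nn.Linear(3 * emb_dim, 1)

    def forward(self, node_feats, dag_index):
        h_node = self.node_enc(node_feats)               # [N, emb]
        h_dag = self.dag_enc(h_node, dag_index)          # [num_dags, emb]
        h_glob = self.glob_enc(h_dag)                    # [1, emb]
        feats = torch.cat([h_node,
                           h_dag[dag_index],             # DAG emb broadcast to nodes
                           h_glob.expand(len(h_node), -1)], dim=-1)
        return self.head(feats).squeeze(-1)              # one logit per node
```

Because the logits are a function of all three encoders, any loss computed from them backpropagates into all of their parameters, which is the point in the next reply.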
PPO and [NodeEncoder, DagEncoder, GlobalEncoder] are updated based on the same loss
PPO is not updated, as it's a procedure, not a model. You're correct that all the modules are updated based on the same loss; since all the modules are linked together, the loss depends on all of their parameters.
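To make that concrete: if the trainer holds one optimizer over the policy's full parameter set, a single backward pass on the PPO loss updates all three encoders together. Below is a simplified sketch (reusing the hypothetical PolicyNet from the previous block, with value and entropy terms omitted), not the actual trainers code:

```python
# Simplified PPO-style update for a single observation; the real code in the
# trainers directory is more involved, but the gradient flow is the same.
# PolicyNet is the illustrative class from the sketch above, so NodeEncoder,
# DagEncoder, and GlobalEncoder are all submodules of `policy`, and
# policy.parameters() therefore covers every encoder's parameters.
import torch

policy = PolicyNet()
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

def ppo_update(node_feats, dag_index, action, old_logp, advantage, clip=0.2):
    # `action` is a 0-dim LongTensor holding the chosen node index.
    logits = policy(node_feats, dag_index)                  # one logit per node
    dist = torch.distributions.Categorical(logits=logits)   # distribution over nodes
    logp = dist.log_prob(action)
    ratio = torch.exp(logp - old_logp)
    # Clipped surrogate objective (value and entropy terms omitted for brevity).
    surrogate = torch.min(ratio * advantage,
                          torch.clamp(ratio, 1 - clip, 1 + clip) * advantage)
    loss = -surrogate

    optimizer.zero_grad()
    loss.backward()    # gradients reach NodeEncoder, DagEncoder, GlobalEncoder
    optimizer.step()   # ...and this single step updates all of them
    return loss.item()
```

The actual optimization in the trainers directory additionally handles batching and the full PPO objective, but the key point is the same: one loss, one optimizer over all parameters.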
Thank you.
Hello, where is the training code for the graph neural network (NodeEncoder, DagEncoder, GlobalEncoder) in Decima?