tkn-tub / ns3-gym

ns3-gym - The Playground for Reinforcement Learning in Networking Research
GNU General Public License v2.0

Casting Box Datatype #78

Closed haidlir closed 1 year ago

haidlir commented 1 year ago

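I tried to integrate ns3-gym with stable-baselines. The following error occurred as soon as the learning process started, and the cause is fairly clear: the observation reaches the agent as a protobuf container rather than a NumPy array.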

Problem:

    ********** Iteration 0 ************
    Traceback (most recent call last):
      File "ns3_stable_solve.py", line 101, in <module>
        model.learn(total_timesteps=512)
      File "/home/haidlir/ns3/venv/lib/python3.7/site-packages/stable_baselines/ppo1/pposgd_simple.py", line 238, in learn
        seg = seg_gen.__next__()
      File "/home/haidlir/ns3/venv/lib/python3.7/site-packages/stable_baselines/trpo_mpi/utils.py", line 57, in traj_segment_generator
        action, vpred, states, _ = policy.step(observation.reshape(-1, *observation.shape), states, done)
    AttributeError: 'google.protobuf.pyext._message.RepeatedScalarConta' object has no attribute 'reshape'
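The failing line adds a leading batch dimension with `observation.reshape(-1, *observation.shape)`, which only works on an `np.ndarray`. A minimal sketch reproducing the failure, assuming a plain Python list as a stand-in for the protobuf `RepeatedScalarContainer` (both are sequences with no `reshape` method); `add_batch_dim` is a hypothetical name for illustration:

```python
import numpy as np


def add_batch_dim(observation):
    # Same expression stable-baselines evaluates in traj_segment_generator
    return observation.reshape(-1, *observation.shape)


raw_obs = [0.1, 0.2, 0.3]  # stand-in for the protobuf repeated field

try:
    add_batch_dim(raw_obs)
except AttributeError as err:
    # Fails before reshape is ever called: a list has no .reshape attribute
    print("fails:", err)

# The same observation wrapped in an ndarray gets a batch dimension of size 1
batched = add_batch_dim(np.array(raw_obs))
print(batched.shape)
```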

Solution: I recommend casting the Box data to np.array just before it is returned, as in https://github.com/haidlir/ns3-gym/blob/da37d708afa1214078127c4718f0a95be4627122/src/opengym/model/ns3gym/ns3gym/ns3env.py#L258
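A minimal sketch of the proposed cast, assuming a hypothetical helper that receives the repeated field from the Box container together with an optional shape. The names `box_to_ndarray`, `values`, `shape` and the default `float32` dtype are illustrative assumptions, not ns3-gym's actual API; a plain list again stands in for the protobuf container:

```python
import numpy as np


def box_to_ndarray(values, shape=None, dtype=np.float32):
    # Hypothetical helper (not part of ns3-gym): copy the repeated field
    # into a real ndarray; list() first so any iterable sequence works
    data = np.array(list(values), dtype=dtype)
    # Optionally restore the Box shape when one is provided
    if shape:
        data = data.reshape(tuple(shape))
    return data


# Usage: a flat 4-element field restored to a 2x2 observation
obs = box_to_ndarray([1, 2, 3, 4], shape=[2, 2])
print(obs.shape, obs.dtype)
```

Copying into an ndarray once, at the point where the environment returns the observation, means every downstream consumer (stable-baselines, Gym wrappers, user code) can rely on NumPy methods such as `reshape`.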

Moreover, np.array is the de facto data type throughout the ML ecosystem. Thanks.

pgawlowicz commented 1 year ago

@haidlir Thanks for this patch!