alexhernandezgarcia / gflownet

Generative Flow Networks - GFlowNet
https://gflownet.readthedocs.io/en/latest/
Apache License 2.0

Simplify tfloat, tlong, tint, tbool... #312

Open alexhernandezgarcia opened 3 months ago

alexhernandezgarcia commented 3 months ago

Currently, the codebase uses helper methods (tfloat, tlong, tint, tbool) to convert numbers / lists / arrays into tensors with the corresponding dtype and send them to the right device. These methods are implemented in gflownet/utils/common.py.

For example, if we want to convert a batch of states into a float tensor, we can do the following:

states = tfloat(states, device=self.device, float_type=self.float)
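
For readers without the file at hand, here is a minimal sketch of what such helpers might look like. It is an assumption based on the description above, not the actual code in gflownet/utils/common.py. The shared `_to_tensor` function is a hypothetical name, and the signatures of tlong, tint and tbool (including the `int_type` parameter) are guesses.

```python
import torch


def _to_tensor(x, dtype, device):
    # Hypothetical shared helper: tensors are cast/moved with .to(),
    # anything else (number, list, NumPy array) goes through torch.tensor().
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)
    return torch.tensor(x, dtype=dtype, device=device)


def tfloat(x, device, float_type):
    # Float tensor with an explicit precision, e.g. torch.float32.
    return _to_tensor(x, float_type, device)


def tlong(x, device):
    # 64-bit integer tensor.
    return _to_tensor(x, torch.long, device)


def tint(x, device, int_type):
    # Integer tensor with an explicit integer type, e.g. torch.int16.
    return _to_tensor(x, int_type, device)


def tbool(x, device):
    # Boolean tensor.
    return _to_tensor(x, torch.bool, device)
```

With this sketch, every call site has to repeat the device and dtype, e.g. `tfloat([[0, 1], [2, 3]], device="cpu", float_type=torch.float32)`.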

While this has some advantages, it is rather annoying that we have to explicitly pass device and float_type, which end up making a pretty long line.

I wonder if there is a neat and simple way of changing things so that we could just write:

states = tfloat(states)
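
One possible approach, sketched here under assumptions: bind the device and dtype once per object with functools.partial, so that call sites only pass the data. The `Env` class and its `states_to_tensor` method are hypothetical stand-ins for any class that owns `self.device` and `self.float`; this is not how gflownet currently does it.

```python
from functools import partial

import torch


def tfloat(x, device, float_type):
    # Helper with explicit device and float_type, as described in the issue.
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=float_type)
    return torch.tensor(x, dtype=float_type, device=device)


class Env:
    """Hypothetical class that owns self.device and self.float."""

    def __init__(self, device="cpu", float_type=torch.float32):
        self.device = torch.device(device)
        self.float = float_type
        # Bind device and dtype once; later calls only pass the data.
        self.tfloat = partial(tfloat, device=self.device, float_type=self.float)

    def states_to_tensor(self, states):
        # Hypothetical method: the call site is now as short as requested.
        return self.tfloat(states)
```

The call site becomes `states = self.tfloat(states)`. An alternative would be module-level defaults set once at start-up, which gives a literal `tfloat(states)` but introduces global state shared by every object in the process.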

engmubarak48 commented 3 months ago

which end up making a pretty long line.

I wonder what is wrong with a long line. If length is the only issue, a code formatter can take care of that. Or is there some other problem with it?

Another option is to move the tensors to the device and convert them to float early on, and then simply call states = tfloat(states).
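
A sketch of this second option, assuming a hypothetical variant of tfloat whose device and float_type arguments are optional. If the states were already converted to a float tensor on the right device, the later call costs nothing, because torch.Tensor.to returns the tensor itself when its dtype and device already match.

```python
import torch


def tfloat(x, device=None, float_type=None):
    """Hypothetical variant of tfloat with optional arguments.

    float_type defaults to PyTorch's global default dtype; device defaults to
    the tensor's current device (or PyTorch's default device for non-tensors).
    """
    if float_type is None:
        float_type = torch.get_default_dtype()
    if isinstance(x, torch.Tensor):
        # .to() returns x itself (no copy) when dtype and device already match.
        return x.to(device=x.device if device is None else device, dtype=float_type)
    return torch.tensor(x, dtype=float_type, device=device)


# Convert once, early on (e.g. when the batch of states is created)...
states = torch.tensor([[0.0, 1.0], [2.0, 3.0]], device="cpu")
# ...so that later calls can be as short as tfloat(states).
same_states = tfloat(states)
```

One caveat of this design: relying on torch.get_default_dtype() means that float64 states passed without an explicit float_type would be silently cast to the global default (float32 unless changed with torch.set_default_dtype).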