VzxPLnHqr opened this issue 2 months ago.
This does not seem right, but here is what I have come up with so far:
```scala
import cats.effect.IO
import scala.util.chaining.* // provides .pipe
import torch.*

// `layerSizes` is defined elsewhere (not shown here)
class NeuralNetwork extends nn.Module:
  val flatten = nn.Flatten[Float64]()
  // fixed architecture for now of fully connected tanh inner layers
  val linearTanhStack = List.range(0, layerSizes.size - 1)
    .foldLeft(List.empty[TensorModule[Float64]]) {
      case (layers, index) =>
        val linear = nn.Linear[Float64](layerSizes(index), layerSizes(index + 1))
        val act = if index < layerSizes.size - 2 then nn.Tanh[Float64]() else nn.Identity[Float64]()
        layers :+ linear :+ act
    }.pipe(nn.Sequential[Float64](_*)).pipe(register(_))

  def apply(x: Tensor[Float64]) =
    val flattened = flatten(x)
    val logits = linearTanhStack(flattened)
    logits

val policyNet = NeuralNetwork()
val targetNet = NeuralNetwork()

def syncTargetNetToPolicyNet: IO[Unit] =
  for
    stateDict <- IO(policyNet.namedParameters(true).map((s, t) => (s, Tensor.fromNative[DType](t.native))))
  yield targetNet.loadStateDict(stateDict)
```
Without the `Tensor.fromNative[DType]` step, I could not get it to compile. The signature of `namedParameters` uses the `Tensor[?]` wildcard type, whereas `loadStateDict` requires a `Map[String, Tensor[DType]]`.
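Here is a storch-free model of the typing problem that shows why the rewrap is needed. `DType`, `Float32`, `Float64` and `Box` are hypothetical stand-ins for storch's dtypes and `Tensor`, and the model assumes `Tensor` is invariant in its dtype parameter. Under that assumption a `Box[Float64]` is not a `Box[DType]`, so a map of `Box[?]` values is rejected where a `Map[String, Box[DType]]` is required. Rewrapping the same data under the wider static type, which is the analogue of `Tensor.fromNative[DType](t.native)`, does type-check.

```scala
// Hypothetical stand-ins for storch's dtypes and Tensor.
sealed trait DType
final class Float32 extends DType
final class Float64 extends DType

// Invariant in D, like Tensor[D] is assumed to be here.
final class Box[D <: DType](val raw: Array[Double])

object InvarianceDemo:
  // Mirrors a loadStateDict that demands exactly Box[DType] values.
  def loadStrict(sd: Map[String, Box[DType]]): Int = sd.size

  val params: Map[String, Box[?]] =
    Map("w" -> Box[Float64](Array(1.0, 2.0)), "b" -> Box[Float32](Array(0.5)))

  // loadStrict(params) // does not compile: Box[?] does not conform to Box[DType]

  // Rewrap the same underlying arrays under the static type Box[DType].
  val rewrapped: Map[String, Box[DType]] =
    params.map((k, b) => k -> Box[DType](b.raw))

  val loaded: Int = loadStrict(rewrapped)
```

The rewrap only changes the static type. Both maps still point at the same arrays.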
@VzxPLnHqr you're right, these should be more consistent. Given that `Tensor` is currently invariant, `Tensor[DType]` is a bit awkward in general.
Looking at the implementation of `loadStateDict`, I think we should be able to change the argument to `stateDict: Map[String, Tensor[?]]`.
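The proposed wildcard signature can be illustrated with a small, storch-free model. `DType`, `Float32`, `Float64` and `Box` are hypothetical stand-ins for storch's dtypes and an invariant `Tensor`. `Map` is covariant in its value type, and `Box[Float64]` conforms to `Box[?]`. A parameter typed `Map[String, Box[?]]` therefore accepts both mixed-dtype maps and single-dtype maps without any rewrapping.

```scala
// Hypothetical stand-ins for storch's dtypes and an invariant Tensor.
sealed trait DType
final class Float32 extends DType
final class Float64 extends DType
final class Box[D <: DType](val raw: Array[Double])

object WildcardDemo:
  // Shape of the proposed signature: any element dtype is accepted.
  def loadAny(sd: Map[String, Box[?]]): Int = sd.size

  val mixed: Map[String, Box[?]] =
    Map("w" -> Box[Float64](Array(1.0)), "b" -> Box[Float32](Array(0.0)))
  val onlyF64: Map[String, Box[Float64]] = Map("w" -> Box[Float64](Array(1.0)))

  // Both calls compile even though Box itself is invariant.
  val a: Int = loadAny(mixed)
  val b: Int = loadAny(onlyF64)
```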
Could you perhaps try something like this `loadStateDict` function as a workaround?
```scala
import torch.*
import torch.nn.Module

def loadStateDict(m: Module, stateDict: Map[String, Tensor[?]]): Unit =
  val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
  for ((key, param) <- tensorsToLoad if stateDict.contains(key))
    noGrad {
      param.copy_(stateDict(key))
    }
```
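The semantics of this loop can be made explicit with a storch-free model. Only keys that the target module already has are touched, and state-dict entries without a matching key are ignored. Values are also copied into the existing parameters rather than aliased, so later changes to the source do not leak into the target. `Param` and `copyFrom` are hypothetical stand-ins for a parameter tensor and `Tensor.copy_`.

```scala
import scala.collection.immutable.SeqMap

// Hypothetical stand-in for a parameter tensor; copyFrom plays the role of copy_.
final class Param(val data: Array[Double]):
  def copyFrom(src: Param): Unit =
    require(src.data.length == data.length, "shape mismatch")
    Array.copy(src.data, 0, data, 0, data.length)

object CopyDemo:
  // Same loop shape as the workaround: iterate the target, copy matching keys.
  def loadStateDict(target: SeqMap[String, Param], sd: Map[String, Param]): Unit =
    for ((key, param) <- target if sd.contains(key)) param.copyFrom(sd(key))

  val policy: Map[String, Param] =
    Map("w" -> Param(Array(1.0, 2.0)), "b" -> Param(Array(3.0)), "extra" -> Param(Array(9.0)))
  val target: SeqMap[String, Param] =
    SeqMap("w" -> Param(Array(0.0, 0.0)), "b" -> Param(Array(0.0)), "c" -> Param(Array(7.0)))

  loadStateDict(target, policy)
  // Simulate a later training step on the policy network.
  policy("w").data(0) = 42.0
```

After the load, `target("w")` still holds `1.0, 2.0` even though the policy was modified afterwards. `"c"` is left untouched, and `"extra"` is simply skipped.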
If it works for you, I think we could also change the signature of the actual `loadStateDict` method in `Module`.
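One point to weigh when changing the signature is that the workaround silently skips any key missing on either side. PyTorch's `Module.load_state_dict` instead reports `missing_keys` and `unexpected_keys`, and raises on a mismatch when `strict=True`. A minimal helper that computes the same two sets might look like the sketch below. `KeyReport` and `keyReport` are hypothetical names, not storch API.

```scala
// Hypothetical helper (not part of storch): which keys would a load skip?
final case class KeyReport(missing: Set[String], unexpected: Set[String]):
  def isExact: Boolean = missing.isEmpty && unexpected.isEmpty

def keyReport(targetKeys: Iterable[String], stateDictKeys: Iterable[String]): KeyReport =
  val t = targetKeys.toSet
  val s = stateDictKeys.toSet
  // missing: expected by the target but absent; unexpected: present but unused
  KeyReport(missing = t -- s, unexpected = s -- t)
```

A strict variant could build this report from the target's parameter and buffer names and the state dict's keys, then fail before copying anything when `isExact` is false.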
@sbrunk Thank you for your reply. I added the workaround `loadStateDict` method you provided as an extension method on `nn.Module` (wrapped in `IO`, since I am using cats-effect):
```scala
extension (m: nn.Module)
  def loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit] = IO {
    val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
    for ((key, param) <- tensorsToLoad if stateDict.contains(key))
      noGrad {
        param.copy_(stateDict(key))
      }
  }
```
I can then use the new method like so:
```scala
targetNet.loadStateDict(policyNet.namedParameters(true).toMap)
```
Unfortunately, the above gives a Bloop compile error:
```
Cannot prove that (String, torch.Tensor[?]) <:< (String, V2).
where: V2 is a type variable with constraint <: torch.Tensor[torch.DType]
```
Update: it turns out the `.toMap` I had appended was causing the compilation issue. Changing the signature from `loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit]` to `loadStateDict(stateDict: SeqMap[String, Tensor[?]]): IO[Unit]` at least fixed that.
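The `torch.Tensor[torch.DType]` bound in the error suggests that the `.toMap` call was typed against `Module`'s own `loadStateDict(Map[String, Tensor[DType]])` rather than against the new extension. In Scala 3, if the receiver already has a member with that name that accepts the arguments, the member is chosen and an extension method with the same name is not used. The minimal model below illustrates this. `Mod`, `load` and `loadFrom` are hypothetical names, with `Mod` standing in for `nn.Module`. Giving the extension a distinct name avoids the question entirely.

```scala
// Hypothetical stand-in for nn.Module with its own `load` member.
class Mod:
  def load(sd: Map[String, Int]): String = "member"

extension (m: Mod)
  // Same name as the member: not used when the member applies.
  def load(sd: Map[String, Any]): String = "extension"
  // Distinct name: always resolves to the extension.
  def loadFrom(sd: Map[String, Any]): String = "extension"

object ShadowDemo:
  val viaSameName: String = Mod().load(Map("w" -> 1))
  val viaNewName: String = Mod().loadFrom(Map("w" -> 1))
```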
Have not had a chance yet to test the actual functionality though. Will do that next and let you know if it works as expected. If so, then yes, this seems like it would be a good usability improvement.
@sbrunk Thank you for this excellent library! I have been trying to re-implement this cart-pole deep Q-learning example using storch and cats-effect. In that article, there is Python code that initializes the two networks. Every so often, as the `policy_net` gets trained, the `target_net` needs to be updated.

However, I cannot find a way to access `stateDict` in storch. Is there a more Scala/storch-recommended way to take one network (the `target_net` in that example) and load it so that it is initially equivalent to a different network (the `policy_net`)? I noticed that the `loadStateDict` method is available in storch, but I cannot figure out what to feed into it. Any help is much appreciated. Thanks!
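For context on the "every so often" part, DQN implementations commonly keep the target network in sync in one of two ways. The first is a hard copy of the policy weights every fixed number of steps. The second is a soft (Polyak) update, `target <- tau * policy + (1 - tau) * target`, applied after every step. Which variant the article uses is not shown here. The sketch below models both on plain arrays. `TargetUpdate` and the `Map[String, Array[Double]]` parameter representation are hypothetical. In storch terms, the hard copy corresponds to what the `loadStateDict` workaround in this thread does.

```scala
// Hypothetical, storch-free model: parameters as name -> flat weight array.
object TargetUpdate:
  type Params = Map[String, Array[Double]]

  // Hard update: overwrite each target array in place with the policy values.
  def hardUpdate(target: Params, policy: Params): Unit =
    target.foreach((k, t) => Array.copy(policy(k), 0, t, 0, t.length))

  // Soft (Polyak) update: target <- tau * policy + (1 - tau) * target.
  def softUpdate(target: Params, policy: Params, tau: Double): Unit =
    target.foreach { (k, t) =>
      val p = policy(k)
      for i <- t.indices do t(i) = tau * p(i) + (1 - tau) * t(i)
    }
```

Both functions write into the target's existing arrays, so the target never aliases the policy's storage.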