sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

how to access `stateDict`? #78

Open VzxPLnHqr opened 2 months ago

VzxPLnHqr commented 2 months ago

@sbrunk Thank you for this excellent library! I have been trying to re-implement this cart-pole deep Q-learning example using storch and cats-effect. In that article, the following Python code initializes the two networks:

        self.policy_net = self.build_network(layer_sizes)
        self.target_net = self.build_network(layer_sizes)
        self.target_net.load_state_dict(self.policy_net.state_dict())

Every so often, as the policy_net gets trained, the target_net needs to be updated.
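
Concretely, the loop I have in mind looks roughly like this (just a sketch; `trainStep`, `syncTargetNet`, and `syncEvery` are placeholders, not real storch or cats-effect API):

    import cats.effect.IO
    import cats.syntax.all.*

    // hypothetical training loop: one optimization step per iteration,
    // copying the policy net into the target net every syncEvery steps
    def trainLoop(step: Int): IO[Unit] =
      for
        _ <- trainStep                                  // placeholder: one DQN update of the policy net
        _ <- syncTargetNet.whenA(step % syncEvery == 0) // placeholder: the sync this issue is about
        _ <- trainLoop(step + 1)
      yield ()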

However, I cannot seem to find a way to access stateDict in storch. Is there a recommended Scala/storch way to take one network (the target_net in the example above) and load it so that it is initially equivalent to a different network (the policy_net)?

I noticed that storch has a loadStateDict method, but I cannot figure out what to feed into it. Any help is much appreciated. Thanks!

VzxPLnHqr commented 2 months ago

This does not seem right, but here is what I have come up with so far:

    import scala.util.chaining.* // for .pipe

    // layerSizes: List[Int] comes from the enclosing scope
    class NeuralNetwork extends nn.Module:
      val flatten = nn.Flatten[Float64]()
      // fixed architecture for now: fully connected layers with tanh activations
      // on the inner layers and an identity on the output layer
      val linearTanhStack = List.range(0, layerSizes.size - 1)
        .foldLeft(List.empty[TensorModule[Float64]]) { case (layers, index) =>
          val linear = nn.Linear[Float64](layerSizes(index), layerSizes(index + 1))
          val act = if index < layerSizes.size - 2 then nn.Tanh[Float64]() else nn.Identity[Float64]()
          layers :+ linear :+ act
        }
        .pipe(layers => nn.Sequential[Float64](layers*)) // splice the layer list into a Sequential
        .pipe(register(_)) // register the stack so its parameters show up in namedParameters

      def apply(x: Tensor[Float64]) =
        val flattened = flatten(x)
        val logits = linearTanhStack(flattened)
        logits

    val policyNet = NeuralNetwork()
    val targetNet = NeuralNetwork()

    def syncTargetNetToPolicyNet: IO[Unit] =
      for
        // re-wrap each parameter so the values are Tensor[DType] rather than Tensor[?]
        stateDict <- IO(policyNet.namedParameters(true).map((s, t) => (s, Tensor.fromNative[DType](t.native))))
      yield targetNet.loadStateDict(stateDict)

Without the Tensor.fromNative[DType] step, I could not get it to compile: namedParameters has Tensor[?] values in its return type, whereas loadStateDict requires a Map[String, Tensor[DType]].
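
That is, the obvious direct call is rejected by the compiler (sketch of the failing line):

    // rejected: namedParameters yields Tensor[?] values,
    // while loadStateDict wants a Map[String, Tensor[DType]]
    targetNet.loadStateDict(policyNet.namedParameters(true))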

sbrunk commented 2 months ago

@VzxPLnHqr you're right, these should be more consistent. Given that Tensor is currently invariant, Tensor[DType] is a bit awkward in general.
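
To illustrate the invariance issue (a minimal sketch; it only assumes Float64 <: DType, which holds in storch's dtype hierarchy):

    def example(t64: Tensor[Float64]): Unit =
      // val asDType: Tensor[DType] = t64 // does not compile: Tensor is invariant in its dtype
      val asWild: Tensor[?] = t64 // compiles: the wildcard matches any dtype
      ()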

Looking at the implementation of loadStateDict, I think we should be able to change the argument to stateDict: Map[String, Tensor[?]].

https://github.com/sbrunk/storch/blob/2dfa3884b9f0f2d1e2566aad791f44535b48bb09/core/src/main/scala/torch/nn/modules/Module.scala#L54-L60

Could you perhaps try to use something like this loadStateDict function as a workaround?

    def loadStateDict(m: Module, stateDict: Map[String, Tensor[?]]): Unit =
      val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
      // copy each matching tensor in place, keeping the copy off the autograd graph
      for ((key, param) <- tensorsToLoad if stateDict.contains(key))
        noGrad {
          param.copy_(stateDict(key))
        }
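
Usage would then look something like this (untested):

    // make targetNet's parameters an exact copy of policyNet's
    loadStateDict(targetNet, policyNet.namedParameters(true).toMap)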

If it works for you, I think we could also change the signature of the actual loadStateDict method in Module.

VzxPLnHqr commented 2 months ago

@sbrunk Thank you for your reply. I added the workaround loadStateDict method you provided as an extension method to nn.Module (wrapped in IO since I am using cats-effect):

    extension (m: nn.Module)
      def loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit] = IO {
        val tensorsToLoad = m.namedParameters() ++ m.namedBuffers()
        // copy each matching tensor in place, keeping the copy off the autograd graph
        for ((key, param) <- tensorsToLoad if stateDict.contains(key))
          noGrad {
            param.copy_(stateDict(key))
          }
      }

I can then use the new method like so:

    targetNet.loadStateDict(policyNet.namedParameters(true).toMap)

Unfortunately the above gives a bloop/compile error:

    Cannot prove that (String, torch.Tensor[?]) <:< (String, V2).

    where:    V2 is a type variable with constraint <: torch.Tensor[torch.DType]

VzxPLnHqr commented 2 months ago

Update: it turns out the .toMap I had to append there was causing the compilation issue. Changing the signature from loadStateDict(stateDict: Map[String, Tensor[?]]): IO[Unit] to loadStateDict(stateDict: SeqMap[String, Tensor[?]]): IO[Unit] at least fixed that.
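
For reference, the version that compiles now looks like this (same body as before, plus the SeqMap import):

    import scala.collection.immutable.SeqMap

    extension (m: nn.Module)
      def loadStateDict(stateDict: SeqMap[String, Tensor[?]]): IO[Unit] = IO {
        // ... same body as above ...
      }

and the call site is simply targetNet.loadStateDict(policyNet.namedParameters(true)), since namedParameters apparently already returns a SeqMap.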

I have not had a chance to test the actual functionality yet, though. I will do that next and let you know if it works as expected. If so, then yes, this seems like it would be a good usability improvement.