sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Implement "pico" GPT example #51

Open hmf opened 1 year ago

hmf commented 1 year ago

Response to request in issue https://github.com/sbrunk/storch/issues/44.

Attempt to rewrite the "pico" example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" in storch.

hmf commented 1 year ago

@sbrunk or anyone else: I need some assistance with this work. To code the "pico" example, I need the Embedding operator, which I have added in my branch. I have also added comments and made sure the ScalaDoc is OK (minus the math expressions).

The code I am working on now is the BiGram class. If I understand the code correctly, I have to pass a Tensor of shape (B,T) and get a Float tensor back. According to the native code, that seems to be a call to the forward method, so I am using this in the Embedding class, as in the other modules:

  def apply(t: Tensor[Int64]): Tensor[D] = Tensor(nativeModule.forward(t.native))

And this is a problem because I get the error:

[error] 101 |final class Embedding[D <: DType: Default](
[error]     |            ^
[error]     |class Embedding needs to be abstract, since def apply(v1: T1): R in trait Function1 in package scala is not defined 
[error]     |(Note that
[error]     | parameter T1 in def apply(v1: T1): R in trait Function1 in package scala does not match
[error]     | parameter torch.Tensor[torch.Int64] in def apply(t: torch.Tensor[torch.Int64]): torch.Tensor[D] in class Embedding in package torch.nn.modules.embed
[error]     | )

I think this is because we extend from TensorModule:

trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):

In other words, the apply from (Tensor[D] => Tensor[D]) assumes the input and output are of the same type. Do we have other operators where this is not true? If not, how should we handle this?

On a related note, is it possible to constrain the Tensor by its shape?

TIA

hmf commented 1 year ago

In order to keep going I have used the following solution:

  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
  @targetName("apply_T_D")
  def apply[T<:DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))

Is this ok for a final solution?

sbrunk commented 1 year ago

@hmf You're right, Embedding is an example where the input type may differ from the output type, so we can't inherit from TensorModule.

Note that @davoclavo has also added Embedding and a few other modules in #36 (haven't been able to finish and merge that yet, unfortunately) and added a more generic TensorModuleBase to tackle this issue:

https://github.com/sbrunk/storch/blob/05f7dbdca35daa0589447ad0d4eadbefe38e1aeb/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala#L58-L68

https://github.com/sbrunk/storch/blob/05f7dbdca35daa0589447ad0d4eadbefe38e1aeb/core/src/main/scala/torch/nn/modules/Module.scala#L125-L127

So eventually we need to merge your solutions, but for now you could also just inherit from nn.Module and then add your apply method:

  def apply[T<:DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))
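To make the typing issue concrete, here is a minimal plain-Scala sketch of a base trait parameterized on separate input and output dtypes. Everything in it (`DType`, `Int64`, `Float32`, `Tensor`, `Module`, `TensorModuleBase`) is a hypothetical stand-in defined locally, not storch's actual API; the real `TensorModuleBase` in #36 may look different.

```scala
// Stand-ins for storch types, only to illustrate the typing pattern (NOT storch's real API).
sealed trait DType
sealed trait Int64 extends DType
sealed trait Float32 extends DType
final case class Tensor[D <: DType](values: Vector[Double])

trait Module

// Hypothetical base trait: a module whose input and output dtypes may differ.
trait TensorModuleBase[In <: DType, Out <: DType] extends Module with (Tensor[In] => Tensor[Out])

// The same-dtype case becomes a special case of the base trait.
trait TensorModule[D <: DType] extends TensorModuleBase[D, D]

// Embedding maps integer indices to float vectors, so In != Out.
final class Embedding(weights: Vector[Vector[Double]]) extends TensorModuleBase[Int64, Float32]:
  // Look up one row of the weight table per index and concatenate the rows.
  def apply(indices: Tensor[Int64]): Tensor[Float32] =
    Tensor(indices.values.flatMap(i => weights(i.toInt)))
```

Because `Embedding` extends `Tensor[Int64] => Tensor[Float32]`, its `apply` implements `Function1.apply` directly, so there is no unimplemented abstract `apply` left for the compiler to complain about.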

On a related note, is it possible to constrain the Tensor by its shape?

Right now, we're tracking only the dtype at compile time. We might add that in the future though.

hmf commented 1 year ago

@sbrunk I have looked at the Embedding class and my version is pretty close to it. I currently cannot search @davoclavo's branch, but I think I can copy and use that code (a minimal set of classes with updated docs). That might be easier on your side.

In the meantime, if you do merge into the main branch, I will update accordingly. OK with you?

sbrunk commented 1 year ago

Sounds good to me 👍

hmf commented 1 year ago

Question about cross entropy functions. The original code uses something like:

import torch
import torch.nn as nn
from torch.nn import functional as F

...
loss = F.cross_entropy(logits, targets)
...
            probs = F.softmax(logits, dim=-1) # (B, C)

I see that we have 2 options, a function in the Loss package (does not exist yet, only binary version available) and the torch.nn.loss.CrossEntropyLoss version. The storch examples use the latter.

What are the advantages/disadvantages of using one or the other?

sbrunk commented 1 year ago

I see that we have 2 options, a function in the Loss package (does not exist yet, only binary version available) and the torch.nn.loss.CrossEntropyLoss version. The storch examples use the latter.

What are the advantages/disadvantages of using one or the other?

PyTorch has a functional and a class/module variant for most of its nn operations. See torch.nn.functional.cross_entropy and torch.nn.CrossEntropyLoss. The class variant usually inherits from Module, so it's easy to put it into containers expecting modules.

The functional variant does not hold any state; you call it directly with the tensor inputs and other arguments. The class/module variant can be initialized first with init parameters, and then later reused for different inputs. If you have modules with learnable weights/parameters, the module variant also helps you manage that state (makes it easier to update all weights of your model etc.).

For stateless ops without weights, like cross_entropy, the class variant doesn't have much advantage except for reuse, so you can also just use the functional variant. In the end it doesn't make much of a difference.
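As an illustration of the two styles, here is a minimal plain-Scala sketch (no storch; `Seq[Seq[Double]]` stands in for a (N, C) logits tensor) of what cross entropy computes with PyTorch's default mean reduction: log-sum-exp of each row minus the logit of the target class. The `CrossEntropyLoss` class here is a hypothetical wrapper showing the module style, not storch's `torch.nn.loss.CrossEntropyLoss`.

```scala
import scala.math.{exp, log}

// Functional variant: a stateless function of its inputs.
// logits: one row of unnormalized scores per sample; targets: the correct class index per sample.
def crossEntropy(logits: Seq[Seq[Double]], targets: Seq[Int]): Double =
  require(logits.nonEmpty && logits.size == targets.size, "need one target per row")
  val perSample = logits.zip(targets).map { (row, target) =>
    val m = row.max                                      // subtract the max for numerical stability
    val logSumExp = m + log(row.map(x => exp(x - m)).sum)
    logSumExp - row(target)                              // = -log(softmax(row)(target))
  }
  perSample.sum / perSample.size                         // mean over samples (default reduction)

// Module-style variant: constructed once, then applied to many inputs.
// It holds no state here, so it is only a thin wrapper around the function.
final class CrossEntropyLoss:
  def apply(logits: Seq[Seq[Double]], targets: Seq[Int]): Double = crossEntropy(logits, targets)
```

With uniform logits over C classes the loss is `ln C`, which is a handy sanity check for a freshly initialized model.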

davoclavo commented 1 year ago

Hello @hmf! Awesome work on implementing Karpathy's examples. I have made some progress as well, but last month I got sidetracked by some things at work, so I wasn't able to prepare the code for sharing.

I'll leave my progress implementing some of the model building blocks here in case it is helpful in any way to you. As @sbrunk mentioned, there are some new modules implemented in PR #36 - such as Embedding, LayerNorm, ModuleList, etc. - and this code expects those modules to exist in storch.

(Btw, you should be able to access my branch via the PR, or via this direct link)

final case class Head[D <: FloatNN: Default](
    numEmbeddings: Int,
    headSize: Int,
    blockSize: Int,
    dropoutProb: Float
) extends TensorModule[D] {
  val query = register(nn.Linear(numEmbeddings, headSize))
  val key = register(nn.Linear(numEmbeddings, headSize))
  val value = register(nn.Linear(numEmbeddings, headSize))
  val tril = register(torch.tril(torch.ones(Seq(blockSize, blockSize))))
  val dropout = register(Dropout(dropoutProb))

  override def apply(input: Tensor[D]): Tensor[D] =
      val Seq(batch, timeStep, channels) = input.shape // (B, T, C) (64, 256, 384) [Float32]
      assert(timeStep <= blockSize, "Sequence length must not exceed block size")

      val k: Tensor[D] = key(input) // (64, 256, 64) [Float32]
      val q: Tensor[D] = query(input) // (64, 256, 64) [Float32]
      val v: Tensor[D] = value(input) // (64, 256, 64) [Float32]

      // TODO Get rid of the `.to(dtype = q.dtype)`
      val weight =
        torch.matmul(q, torch.transpose(k, -2, -1)) / Tensor(Math.sqrt(channels)).to(dtype = q.dtype) // (64, 256, 256) [Float32]
      val weightMasked =
        weight.maskedFill(
          tril(Slice(0, timeStep), Slice(0, timeStep)) == 0,
          Float.NegativeInfinity
        ) // (64, 256, 256) [Float32]
      val attention =
        torch.nn.functional.softmax(weightMasked, dim = 2)(
          weightMasked.dtype
        ) // (64, 256, 256) [Float32]
      val attentionDropout = dropout(attention) // (64, 256, 256) [Float32]
      val output = attentionDropout.matmul(v) // (64, 256, 64) [Float32]
      output
}

final case class MultiHeadAttention[D <: FloatNN: Default](
    numHeads: Int,
    numEmbeddings: Int,
    headSize: Int,
    blockSize: Int,
    dropoutProb: Float
) extends TensorModule[D] {
  // Multiple heads of self-attention in parallel

  val heads = register(nn.ModuleList(Range(0, numHeads).map { _ =>
    Head[D](numEmbeddings, headSize, blockSize, dropoutProb)
  }*))
  val projection = register(nn.Linear(numHeads * headSize, numEmbeddings))
  val dropout = register(Dropout(dropoutProb))
  override def apply(input: Tensor[D]): Tensor[D] =
      val headOutputs = heads.map { head =>
        head(input)
      } // 6 x (64, 256, 64) [Float32]
      val headOutputsConcat = torch.cat(headOutputs, dim = -1) // (64, 256, 384) [Float32]
      val projectedOutput = projection(headOutputsConcat) // (64, 256, 384) [Float32]
      dropout(projectedOutput) // (64, 256, 384) [Float32]
}

final case class FeedForward[D <: FloatNN: Default](numEmbeddings: Int, dropoutProb: Float)
    extends TensorModule[D] {
  // A simple linear layer followed by a non-linearity

  val net = register(nn.Sequential(
    nn.Linear(numEmbeddings, numEmbeddings * 4),
    nn.ReLU(),
    nn.Linear(numEmbeddings * 4, numEmbeddings),
    Dropout(dropoutProb)
  ))
  override def apply(input: Tensor[D]): Tensor[D] =
    net(input)

}

final case class Block[D <: FloatNN: Default](numEmbeddings: Int, numHeads: Int, blockSize: Int, dropoutProb: Float)
    extends TensorModule[D] {
  // Transformer block: communication followed by computation
  val headSize = numEmbeddings / numHeads // 384 / 6 = 64
  val attention = register(MultiHeadAttention(numHeads, numEmbeddings, headSize, blockSize, dropoutProb))
  val feedForward = register(FeedForward(numEmbeddings, dropoutProb))
  val layerNorm1 = register(nn.LayerNorm(Seq(numEmbeddings)))
  val layerNorm2 = register(nn.LayerNorm(Seq(numEmbeddings)))

  override def apply(input: Tensor[D]): Tensor[D] =
      // (64, 256, 384) [Float32]
      val a = input + attention(layerNorm1(input)) // (64, 256, 384) [Float32]
      val b = a + feedForward(layerNorm2(a)) // (64, 256, 384) [Float32]
      b

}

final case class Dropout[D <: FloatNN: Default](probability: Float) extends TensorModule[D] {
  override def apply(x: Tensor[D]): Tensor[D] =
    nn.functional.dropout(x, probability)
}

I'm happy to assist you in any way to get this to work. I was able to get some inference going without any runtime errors, but haven't had time to train the model on Shakespeare's writings yet.
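For reference, here is a minimal plain-Scala sketch (nested arrays instead of storch tensors) of what one attention head computes for a single sequence: scaled dot-product scores, a causal mask that sets positions above the diagonal to negative infinity (the `maskedFill(tril == 0, -inf)` step), a row-wise softmax, and finally the attention weights multiplied by the values. Scaling by `1/sqrt(headSize)` is an assumption following standard scaled dot-product attention; the function names are hypothetical.

```scala
type Matrix = Array[Array[Double]]

// Causally masked, scaled dot-product attention weights for one sequence.
// q and k have shape (T, headSize); the result has shape (T, T).
def causalAttentionWeights(q: Matrix, k: Matrix): Matrix =
  val t = q.length
  val scale = 1.0 / math.sqrt(q(0).length.toDouble)    // 1 / sqrt(headSize)
  Array.tabulate(t) { i =>
    // Positions j > i lie above the diagonal of the lower-triangular mask and get -infinity.
    val scores = Array.tabulate(t) { j =>
      if j <= i then q(i).indices.map(d => q(i)(d) * k(j)(d)).sum * scale
      else Double.NegativeInfinity
    }
    val m = scores.max                                 // finite, because j = 0 is never masked
    val exps = scores.map(s => math.exp(s - m))        // exp(-inf) == 0.0
    val total = exps.sum
    exps.map(_ / total)                                // softmax over the last dimension
  }

// Output = attention weights times values: (T, T) x (T, headSize).
def attend(weights: Matrix, v: Matrix): Matrix =
  weights.map(row => Array.tabulate(v(0).length)(c => row.indices.map(j => row(j) * v(j)(c)).sum))
```

Note that the head's output is built from the weights after the softmax (and dropout), not from the raw scores.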

I will also be available to continue work on the pending PR to get it merged, in case I can help in any way @sbrunk

davoclavo commented 1 year ago

Oh I forgot, there are also some changes needed for pico GPT that I haven't created a PR for, but I have fixed in my local project. I aim to get these changes submitted soon, but here they are in case you need them earlier:

Tensor#maskedFill

def maskedFill[S <: ScalaType](mask: Tensor[Bool], value: S): Tensor[D] = Tensor(
  native.masked_fill(mask.native, toScalar(value))
)

Tensor#sqrt

def sqrt = Tensor(native.sqrt())

torch.tril

  def tril[D <: DType](input: Tensor[D], diagonal: Int = 0): Tensor[D] =
    Tensor(torchNative.tril(input.native, diagonal.toLong))

Fixing tensor.split (see #39)

  def split[D <: DType](
      input: Tensor[D],
      splitSizeOrSections: Int | Seq[Int],
      dim: Int = 0
  ): Seq[Tensor[D]] = {
    val result =
      splitSizeOrSections match {
        case i: Int      => torchNative.split(input.native, i.toLong, dim.toLong)
        case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
      }
    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
  }
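The two argument forms in `split` follow `torch.split`: an `Int` means chunks of that size (the last one may be smaller), while a `Seq[Int]` gives the size of every chunk, and those sizes must add up to the length of the dimension. A minimal plain-Scala sketch of those semantics on a 1-D `Seq` (hypothetical helper, no storch):

```scala
// Split a sequence either into chunks of a fixed size or into sections of given sizes.
def split[A](xs: Seq[A], splitSizeOrSections: Int | Seq[Int]): Seq[Seq[A]] =
  splitSizeOrSections match
    case size: Int =>
      require(size > 0, "split size must be positive")
      xs.grouped(size).toSeq                       // equal chunks, the last one may be smaller
    case sizes: Seq[Int] @unchecked =>
      require(sizes.sum == xs.length, "section sizes must add up to the dimension length")
      val offsets = sizes.scanLeft(0)(_ + _)       // start index of every section
      offsets.zip(sizes).map((start, len) => xs.slice(start, start + len))
```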

sbrunk commented 1 year ago

I will also be available to continue work on the pending PR to get it merged, in case I can help in any way @sbrunk

@davoclavo feel free to take over #36 again if you have capacity. I've merged main into it with some improvements of the native bindings but since Scala Days is only 4 weeks away I'd like to focus on getting my Storch talk ready first. Happy to help/review etc. but I'm not sure I'll be able to actually work on it before the talk.

davoclavo commented 1 year ago

@sbrunk sounds good, I'll try to polish the last remaining bits.

Best of luck on the Scala Days talk! Hopefully it will be streamed/recorded, I'd love to watch it :D

sbrunk commented 1 year ago

Best of luck on the Scala Days talk! Hopefully it will be streamed/recorded, I'd love to watch it :D

Thanks! I'm sure it will be recorded and put on YouTube some time after the conference, as the videos from the Seattle edition in June are already online. I'll keep you posted :)

hmf commented 1 year ago

@davoclavo Thanks for the assist. Please note that at this time I am working on the very simple "video" version. My aim here is to learn about GPT.

I will look at your code and incorporate all I can to make merging easier.

hmf commented 1 year ago

Questions regarding softmax. I was coding the cross_entropy examples to make sure the typing is correct. In the second example we need the softmax function in the link below. Looking at the code I see we have:

  def softmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
      dtype: Out = input.dtype
  ): Tensor[Out] =
    val nativeDType =
      if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
    Tensor(torchNative.softmax(input.native, dim, nativeDType))

This means that we have to explicitly provide the last (usually empty) parameter list:

  val target1 = F.softmax( input=torch.randn(Seq(3, 5)), dim=1L)()

If we don't, we get the error:

[error] 358 |  val loss1 = F.crossEntropy(input1, target1)
[error]     |                                     ^^^^^^^
[error]     |Found:    (gpt.BiGram.target1 : torch.DType => torch.Tensor[torch.DType])
[error]     |Required: torch.Tensor[O]
[error]     |
[error]     |where:    O is a type variable with constraint <: torch.NumericRealNN

I have made that last parameter an implicit, and did the same for logSoftmax. This way we avoid having to provide that last parameter list. It seems that only the softmax call used it. I ran the tests and had no problems. OK with this change, or am I missing something?

The original Python example code uses a Tensor.softmax(dim=1) call. This method does not exist in storch. The Python documentation states that it is an "Alias for torch.nn.functional.softmax()." Should we add this? If so, do we add it as a standard method or use Scala 3 extension methods?

TIA

sbrunk commented 1 year ago

I have made that last parameter an implicit. I did the same for logSoftmax. If we do this, we avoid having to provide that last parameter. It seems that only the softmax call was used. Ran the test, had no problem. Ok, with this change or am I missing something?

That's fine but could you give the following variant a try? It's a solution we already use in other places and avoids both implicits and multiple parameter lists (at the expense of a slightly more verbose type signature).

import Derive.derive

// ...

  def softmax[In <: DType, Out <: FloatNN | Derive](
      input: Tensor[In],
      dim: Long,
      dtype: Out = derive
  ): Tensor[DTypeOrDeriveFromTensor[In, Out]] =
    // Use the input dtype unless an explicit dtype was passed
    val derivedDType = dtype match
      case _: Derive => input.dtype
      case d: DType  => d
    // Only pass a dtype to the native call when a conversion is actually needed
    val nativeDType =
      if derivedDType == input.dtype then ScalarTypeOptional()
      else ScalarTypeOptional(derivedDType.toScalarType)
    Tensor(torchNative.softmax(input.native, dim, nativeDType))
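The difference between the two signatures can be reproduced without storch. In this plain-Scala sketch (hypothetical `scaleCurried`/`scale` functions, with a numeric default standing in for the `derive` sentinel), a default in a second parameter list still forces callers to write `()`, and forgetting it eta-expands the call into a function, which is exactly the `DType => Tensor[DType]` type in the error above. A single parameter list with a default avoids both the implicit and the extra `()`.

```scala
// Two parameter lists: the trailing list has to be written out, even as `()`.
def scaleCurried(xs: Seq[Double], factor: Double)(offset: Double = 0.0): Seq[Double] =
  xs.map(_ * factor + offset)

// One parameter list with a default: callers can simply omit the argument.
def scale(xs: Seq[Double], factor: Double, offset: Double = 0.0): Seq[Double] =
  xs.map(_ * factor + offset)

val explicitEmptyList = scaleCurried(Seq(1.0, 2.0), 2.0)()   // Seq(2.0, 4.0)
val omitted = scale(Seq(1.0, 2.0), 2.0)                       // Seq(2.0, 4.0)

// Leaving out `()` does not apply the default: the call becomes a function of `offset`.
val notAResult: Double => Seq[Double] = scaleCurried(Seq(1.0, 2.0), 2.0)
```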

The original Python example code uses a Tensor.softmax(dim=1) call. This method does not exist in storch. The Python documentation states that it is an "Alias for torch.nn.functional.softmax()." Should we add this? If so, do we add it as a standard method or use Scala 3 extension methods?

Yes, you can add it as a regular method in Tensor delegating to the implementation in nn.functional
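A minimal plain-Scala sketch of the two options (hypothetical `F` object and `Row` type standing in for `nn.functional` and `Tensor`, not storch code): a regular method that delegates to the functional implementation, and the same alias as an extension method. A class member always takes precedence over an extension method of the same name, which is one reason the regular method is the simpler choice when you own the class.

```scala
// Stand-in for a functional namespace such as nn.functional.
object F:
  // Numerically stable softmax over a 1-D sequence: subtract the max before exponentiating.
  def softmax(xs: Seq[Double]): Seq[Double] =
    val m = xs.max
    val exps = xs.map(x => math.exp(x - m))
    val total = exps.sum
    exps.map(_ / total)

// Hypothetical 1-D "tensor" type used only for this illustration.
final case class Row(values: Seq[Double]):
  // Option 1: a regular method that simply delegates to the functional implementation.
  def softmax: Seq[Double] = F.softmax(values)

// Option 2: the same alias as a Scala 3 extension method, defined outside the class.
extension (row: Row) def softmaxExt: Seq[Double] = F.softmax(row.values)
```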

hmf commented 1 year ago

That's fine but could you give the following variant a try? It's a solution we already use in other places and avoids both implicits and multiple parameter lists (at the expense of a slightly more verbose type signature).

Done (also for logSoftmax). Compiled and all tests pass.

Yes, you can add it as a regular method in Tensor delegating to the implementation in nn.functional

Done:

  def shape: Seq[Int] = size

  def softmax[Out <: FloatNN | Derive](
      dim: Long,
      dtype: Out = derive
  ): Tensor[DTypeOrDeriveFromTensor[D, Out]] = F.softmax(input = this, dim = dim, dtype = dtype)

  def square = Tensor(native.square())

hmf commented 1 year ago

While trying to replicate the Colaboratory notebook to check the code is working, I tried to do the following:

  // We want x[b,t] = mean_{i<=t} x[b,i]
  val xbow = torch.zeros(Seq(b0, t0, c0))
  for b <- 0 until b0
  do
    for t <- 0 until t0
    do
      val xprev = x(b,º`:`t+1) // (t,C)
      xbow(b,t) = torch.mean(xprev, 0)  

The Tensor class has no assignment operator. I also did not find a method for this in the JavaCPP code. How should one go about assigning a value?

TIA

sbrunk commented 1 year ago

The Tensor class has no assignment operator. I also did not find a method for this in the JavaCPP code. How should one go about assigning a value?

The C++ API has a method for assigning values (with indices): See https://pytorch.org/cppdocs/notes/tensor_indexing.html#setter It's just not that easy to find, because it's named index_put_. It's also mapped via JavaCPP, but was missing in Storch.

https://github.com/sbrunk/storch/pull/53 should add support for it. Could you give it a try?
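For reference, the quantity in the loop, x[b,t] = mean over i <= t of x[b,i], can be checked against the matrix form used later in the video: a lower-triangular (T, T) weight matrix whose rows are normalized to sum to 1, multiplied with each (T, C) slice. This is a plain-Scala sketch on nested arrays (hypothetical function names, no storch), where element assignment plays the role that index_put_ provides in the C++ API.

```scala
type Tensor3 = Array[Array[Array[Double]]]   // shape (B, T, C)

// Loop version: xbow(b)(t) = mean over i <= t of x(b)(i), written with indexed assignment.
def bagOfWordsLoop(x: Tensor3): Tensor3 =
  val (bSize, tSize, cSize) = (x.length, x(0).length, x(0)(0).length)
  val xbow = Array.fill(bSize, tSize, cSize)(0.0)
  for b <- 0 until bSize; t <- 0 until tSize; c <- 0 until cSize do
    xbow(b)(t)(c) = (0 to t).map(i => x(b)(i)(c)).sum / (t + 1)
  xbow

// Matrix version: row-normalized lower-triangular weights (T, T) times each (T, C) slice.
def bagOfWordsMatmul(x: Tensor3): Tensor3 =
  val tSize = x(0).length
  val wei = Array.tabulate(tSize, tSize)((t, i) => if i <= t then 1.0 / (t + 1) else 0.0)
  x.map { xb =>   // xb has shape (T, C)
    Array.tabulate(tSize, xb(0).length)((t, c) => (0 until tSize).map(i => wei(t)(i) * xb(i)(c)).sum)
  }
```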

hmf commented 1 year ago

Found some compiler weirdness with the changes above. These do not compile:

      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0)
      xbow(Seq(b,t)) = torch.mean(xprev, dim=0)  

The error is:

method mean in trait ReductionOps: (input: torch.Tensor[?], dtype: torch.Float32): torch.Tensor[torch.Float32] does not have a parameter dim

and (for the last one):

Found:    (0 : Int)
Required: torch.Float32

But these do:

      xbow(b,t) += torch.mean(xprev, dim=0)  
      val c = torch.mean(xprev, dim=0) 
      xbow(Seq(b,t)) = c
      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true, float32)
      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true)

Maybe tweaking the first definition would get it working, but it seems like a Scala compiler issue.

sbrunk commented 1 year ago

It looks like the compiler gets confused by the overloaded variants of mean for whatever reason. I've seen this in other places with different generic overloads.

I realized that the default dim argument with an empty seq defaults to the behavior of the overloaded variants, making them redundant so I've removed them now in #53. Could you give it another try with the changes?

hmf commented 1 year ago

@sbrunk Changes work fine. Thanks.

hmf commented 1 year ago

I need to use Dropout. In Python this seems to return a constructor of sorts (I did not check), which can then be applied to a Tensor.

I see that we have a torch.nn.Dropout that is private to the torch package. So the more obvious solution of having a public Dropout class and its companion object will require changes. I have the following questions:

  1. Is the suggested change above ok?
  2. If so, can I go ahead and change this?
  3. If not, what is the storch way?

EDIT 1:

@davoclavo I realized you have already defined Dropout. I searched your repo but did not find it. Where did you define it? TIA

hmf commented 1 year ago

I would like to use register_buffer. According to the Python API doc, we must pass in a name.

Looking at the org.bytedeco.pytorch.Module we have:

  public Tensor register_buffer(BytePointer name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
  private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString BytePointer name, @ByVal Tensor tensor);
  public Tensor register_buffer(String name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
  private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString String name, @ByVal Tensor tensor);

So in torch.nn.modules.Module something like this should work:

  def registerB[D <: DType](n: String, t: Tensor[D]): Tensor[D] =
    nativeModule.register_buffer(n, t.native)
    t

However, as an example:

  def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
      name: sourcecode.Name
  ): Tensor[D] =
    nativeModule.register_parameter(name.value, t.native, requiresGrad)
    t

the name is implicitly defined. Is there any way I can keep the implicit but still allow manually setting that name?

On a related note, shouldn't these functions return a Tensor(...) wrapping the tensor returned by the native call? We are assuming the same tensor is returned, but this is not guaranteed.

EDIT 1: we also have the problem of duplicate overloaded methods due to the use of default arguments. How should we solve this here? Can I change the names?

EDIT 2: In the meantime I will use:


  def buffer[D <: DType](t: Tensor[D], n: String="")(using
      name: sourcecode.Name
  ): Tensor[D] =
    // Prefer the explicitly passed name, fall back to the name of the enclosing val
    val name_ = if n.trim().isEmpty() then name.value else n.trim()
    Tensor( nativeModule.register_buffer(name_, t.native) )

TIA

sbrunk commented 1 year ago

I need the use of Dropout. In Python this seems to return a constructor of sorts (did not check), which can then be applied to a Tensor.

I see that we have a torch.nn.Dropout that is private to the torch package. So the more obvious solution of having a public Dropout class and its companion object will require changes. I have the following questions:

1. Is the suggested change above ok?

2. If so, can I go ahead and change this?

3. If not, what is the `storch` way?

I think what you found is the Dropout trait in torch.nn.functional, right? The trait is private because its members are exposed through the package object, so you can call it like this:

torch.nn.functional.dropout(input=torch.rand(Seq(3,3)))
// res2: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CPU 
// [[0,4759, 1,4497, 1,7002],
//  [1,2299, 0,0000, 1,1805],
//  [0,0000, 0,0000, 0,0000]]

It corresponds to torch.nn.functional.dropout in Python.

Seems like we're still missing the module variant of Dropout, which corresponds to the Python module you linked to. If you'd like to add that, that would be great! We should put it under torch.nn.modules somewhere, like the other modules.
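To illustrate the difference between the two variants, here is a minimal plain-Scala sketch (on `Seq[Double]`, no storch; `dropout` and `Dropout` here are hypothetical, not storch's API) of inverted dropout: each element is zeroed with probability `p` and survivors are scaled by `1 / (1 - p)`. The printed output above is consistent with that scaling, since for the default `p = 0.5` the surviving values from `torch.rand` in [0, 1) end up in [0, 2). The module variant just stores `p` and a training flag.

```scala
import scala.util.Random

// Functional variant: inverted dropout. Each element is zeroed with probability p and the
// survivors are scaled by 1 / (1 - p), so the expected value of every element is unchanged.
def dropout(xs: Seq[Double], p: Double = 0.5, training: Boolean = true, rng: Random = new Random()): Seq[Double] =
  require(p >= 0.0 && p <= 1.0, "dropout probability must be in [0, 1]")
  if !training || p == 0.0 then xs
  else if p == 1.0 then xs.map(_ => 0.0)
  else xs.map(x => if rng.nextDouble() < p then 0.0 else x / (1.0 - p))

// Module variant: keeps p and a training flag as state and applies the function to each input.
final class Dropout(val p: Double = 0.5):
  var training: Boolean = true
  def apply(xs: Seq[Double]): Seq[Double] = dropout(xs, p, training)
```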

sbrunk commented 1 year ago

So in torch.nn.modules.Module something like this should work:

  def registerB[D <: DType](n: String, t: Tensor[D]): Tensor[D] =
    nativeModule.register_buffer(n, t.native)
    t

However, as an example:

  def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
      name: sourcecode.Name
  ): Tensor[D] =
    nativeModule.register_parameter(name.value, t.native, requiresGrad)
    t

the name is implicitly defined. Is there any way I can keep the implicit but still allow manually setting that name?

We could add an explicit optional name parameter, i.e. defaulting to an empty string, or using an Option. If the caller provides a real name, we take that, otherwise, we fall back to the implicit. Ah I see you've just done that below in the buffer impl :)

On a related note, shouldn't these functions return a Tensor(...) wrapping the tensor returned by the native call? We are assuming the same tensor is returned, but this is not guaranteed.

You're right, it's better to use the tensor returned by the native register method.

EDIT 1: we also have the problem of duplicate overload methods due to the use of defaults. What is the way to solve this here? Can I change the names?

Yes please go ahead. Perhaps we can keep register for modules, because it is used quite often, but use registerParameter, registerBuffer for the others.

EDIT 2: In the meantime I will use:

  def buffer[D <: DType](t: Tensor[D], n: String="")(using
      name: sourcecode.Name
  ): Tensor[D] =
    val name_ = if n.trim().isEmpty() then name.value else n.trim()
    Tensor( nativeModule.register_buffer(name_, t.native) )

👍

davoclavo commented 1 year ago

@davoclavo I realized you have already defined Dropout. I searched your repo but did not find it. Where did you define it? TIA

Hi @hmf ! Apologies for the confusion, I have not committed my changes yet, as I have a bunch of other stuff that needs to be cleaned up. I just shared them in my previous comment to partially share the progress in case it was useful to you :)

You should be able to either drop in that code I shared in your script/example, or add it as a new module to storch.

I'll keep my ear open in case you need any further help, and hopefully find some time soon to help out to contribute these modules to storch.

hmf commented 1 year ago

While trying to implement and debug the multi-head attention mechanism, I have what seems to be unexpected behavior. For a model with the multi-head "only", the code:

    val nuParams = m.parameters.map(_.numel).sum
    println(s"${nuParams} parameters")

Reports:

Multi-head attention
4481 parameters

Now to this model I add the following layer:

    val ffwd = register( FeedFoward(nEmbed) )

where nEmbed = 32. If I count the number of parameters of this layer I get 1056 (nEmbed*nEmbed + nEmbed), which is correct. But the model still reports:

Multi-head attention + FFWD
4481 parameters

Shouldn't that be 4481 + 1056?

TIA

sbrunk commented 1 year ago

@hmf I have a hunch (not tested). Could you try to wrap your Sequential in your feed forward module inside a register as well like so:

https://github.com/sbrunk/storch/blob/5e1fdf2a7b2d985a58ee7a6f8405cd8d443426b4/examples/src/main/scala/gpt/BiGram.scala#L1316-L1326

- val net = nn.Sequential(
+ val net = register(nn.Sequential(

Right now it's registering the layers inside Sequential as submodules of net, but not net itself as a submodule of FeedForward. In Python this is done implicitly. Perhaps we need a macro at some point to achieve something similar in Storch as well.
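The effect can be reproduced with a tiny plain-Scala model of a module tree (hypothetical `Mod`, `Linear`, `Sequential` stand-ins, not storch's classes) where parameters are only counted through registered children. With `nEmbed = 32`, one linear layer has 32 * 32 + 32 = 1056 parameters, and they only show up in the parent's count once the Sequential itself is registered.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for a module tree: parameters are only counted
// for children that were explicitly registered.
abstract class Mod:
  private val children = ListBuffer.empty[Mod]
  protected def ownParams: Int = 0
  def register[M <: Mod](child: M): M =
    children += child
    child
  def numParams: Int = ownParams + children.map(_.numParams).sum

final class Linear(in: Int, out: Int) extends Mod:
  override protected def ownParams: Int = in * out + out   // weight matrix + bias

final class Sequential(layers: Mod*) extends Mod:
  layers.foreach(layer => register(layer))                 // Sequential registers its own layers

final class FeedForwardUnregistered(n: Int) extends Mod:
  val net = Sequential(Linear(n, n))                       // net itself is never registered

final class FeedForwardRegistered(n: Int) extends Mod:
  val net = register(Sequential(Linear(n, n)))
```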

hmf commented 1 year ago

@sbrunk I have confirmed that I need to register the inner modules. As for the macro, maybe a single function that traverses the submodules and registers them would do. But we also have parameter and buffer registration, so that would also have to be dealt with.

Thanks.

hmf commented 1 year ago

I would like to give an update on this endeavor. I have gone through most of the video and am now at the start of the "Block" implementation. I have tried to stick to the video so that I can compare my results. Unfortunately, my results show a much higher loss (single head, and multi-head with 3 heads).

Here are some results:

Single head

Triple Head

I have run about 9 experiments on CPU. Even though convergence is slow, the good news is that it seems to be stable. See below.

Single Head

lr = 1e-5
Output: step 0: train loss 4.315746, val loss 4.3061743 step 500: train loss 4.2083063, val loss 4.2047343 step 1000: train loss 4.109281, val loss 4.1095076 step 1500: train loss 4.024676, val loss 4.021858 step 2000: train loss 3.9401476, val loss 3.9419503 step 2500: train loss 3.861138, val loss 3.868681 step 3000: train loss 3.7746782, val loss 3.7817297 step 3500: train loss 3.6901476, val loss 3.7049506 step 4000: train loss 3.599073, val loss 3.617259 step 4500: train loss 3.5131109, val loss 3.5384142 step 5000: train loss 3.452971, val loss 3.4619794 step 5500: train loss 3.399948, val loss 3.4254942 step 6000: train loss 3.3541067, val loss 3.3918 step 6500: train loss 3.3242495, val loss 3.3732038 step 7000: train loss 3.3144944, val loss 3.3490424 step 7500: train loss 3.2901514, val loss 3.2941566 step 8000: train loss 3.2899778, val loss 3.308439 step 8500: train loss 3.2639534, val loss 3.2906058 step 9000: train loss 3.2651227, val loss 3.2723944 step 9500: train loss 3.2395923, val loss 3.2861238 step 10000: train loss 3.2434728, val loss 3.257814 step 10500: train loss 3.2285821, val loss 3.23281 step 11000: train loss 3.2198544, val loss 3.2416165 step 11500: train loss 3.2021954, val loss 3.2313745 step 12000: train loss 3.195072, val loss 3.2142315 step 12500: train loss 3.1960852, val loss 3.2163675 step 13000: train loss 3.1769931, val loss 3.2013638 step 13500: train loss 3.17453, val loss 3.2119668 step 14000: train loss 3.1472147, val loss 3.1825323 step 14500: train loss 3.1611233, val loss 3.192211 step 15000: train loss 3.1517265, val loss 3.1621974 step 15500: train loss 3.1394618, val loss 3.1598687 step 16000: train loss 3.1233463, val loss 3.145328 step 16500: train loss 3.1227674, val loss 3.1421418 step 17000: train loss 3.1164768, val loss 3.1276824 step 17500: train loss 3.1011841, val loss 3.0985348 step 18000: train loss 3.0856524, val loss 3.11533 step 18500: train loss 3.0842745, val loss 3.0987678 step 19000: train loss 
3.049956, val loss 3.1043591 step 19500: train loss 3.0564034, val loss 3.0689766 step 20000: train loss 3.0590668, val loss 3.0758286 step 20500: train loss 3.0560205, val loss 3.0690722 step 21000: train loss 3.0467145, val loss 3.0635276 step 21500: train loss 3.0318224, val loss 3.0459983 step 22000: train loss 3.025454, val loss 3.0337 step 22500: train loss 3.0058165, val loss 3.0480902 step 23000: train loss 3.0240664, val loss 3.0332391 step 23500: train loss 2.9987218, val loss 3.023562 step 24000: train loss 2.985587, val loss 3.0277314 step 24500: train loss 2.9775257, val loss 3.002483 step 24999: train loss 2.9854958, val loss 3.0055265 step 24999: train loss 2.9771202, val loss 3.0027666

Triple Head

learningRate = 1.0E-5 
maxIterations = 75000
Output: step 0: train loss 4.1618342, val loss 4.16153 step 500: train loss 4.1205373, val loss 4.1242867 step 1000: train loss 4.0790596, val loss 4.081698 step 1500: train loss 4.03232, val loss 4.0372114 step 2000: train loss 3.9790084, val loss 3.9862146 step 2500: train loss 3.9226956, val loss 3.9263957 step 3000: train loss 3.8504639, val loss 3.8638783 step 3500: train loss 3.7733784, val loss 3.786392 step 4000: train loss 3.6981156, val loss 3.720096 step 4500: train loss 3.628634, val loss 3.6443036 step 5000: train loss 3.5587113, val loss 3.5619648 step 5500: train loss 3.4964852, val loss 3.4965785 step 6000: train loss 3.421188, val loss 3.4415543 step 6500: train loss 3.362888, val loss 3.391549 step 7000: train loss 3.3282533, val loss 3.3622048 step 7500: train loss 3.320427, val loss 3.3560946 step 8000: train loss 3.292354, val loss 3.2881603 step 8500: train loss 3.2728596, val loss 3.2815585 step 9000: train loss 3.2583148, val loss 3.2723749 step 9500: train loss 3.2506166, val loss 3.2684808 step 10000: train loss 3.2148948, val loss 3.2601957 step 10500: train loss 3.1988037, val loss 3.2456586 step 11000: train loss 3.206799, val loss 3.2168744 step 11500: train loss 3.1882236, val loss 3.2074182 step 12000: train loss 3.1804316, val loss 3.213266 step 12500: train loss 3.1568613, val loss 3.1786158 step 13000: train loss 3.1662986, val loss 3.1859655 step 13500: train loss 3.1503942, val loss 3.1711109 step 14000: train loss 3.147156, val loss 3.166469 step 14500: train loss 3.1371622, val loss 3.1470597 step 15000: train loss 3.1360898, val loss 3.1625636 step 15500: train loss 3.1335275, val loss 3.1326685 step 16000: train loss 3.1126864, val loss 3.1321425 step 16500: train loss 3.1063373, val loss 3.107653 step 17000: train loss 3.0943058, val loss 3.1191053 step 17500: train loss 3.0952134, val loss 3.1210895 step 18000: train loss 3.1009033, val loss 3.081947 step 18500: train loss 3.0783198, val loss 3.1051643 step 19000: train 
loss 3.0771048, val loss 3.0912302
step 19500: train loss 3.0516539, val loss 3.0886228
step 20000: train loss 3.0385761, val loss 3.0701919
step 20500: train loss 3.042524, val loss 3.0808887
step 21000: train loss 3.0532212, val loss 3.0581033
step 21500: train loss 3.0394647, val loss 3.0615013
step 22000: train loss 3.021087, val loss 3.0511756
step 22500: train loss 3.0327508, val loss 3.0316634
step 23000: train loss 3.0150063, val loss 3.044455
step 23500: train loss 3.0176592, val loss 3.0279248
step 24000: train loss 3.0032563, val loss 3.0306866
step 24500: train loss 2.9985764, val loss 3.031719
step 25000: train loss 2.9964828, val loss 3.0107496
step 25500: train loss 2.9989612, val loss 3.0088224
step 26000: train loss 2.9867206, val loss 3.0070848
step 26500: train loss 2.9651825, val loss 3.009421
step 27000: train loss 2.978981, val loss 2.9872468
step 27500: train loss 2.972667, val loss 2.9928696
step 28000: train loss 2.9587805, val loss 2.9770303
step 28500: train loss 2.9506211, val loss 2.9797046
step 29000: train loss 2.9521976, val loss 2.9750147
step 29500: train loss 2.9423668, val loss 2.9667535
step 30000: train loss 2.9549394, val loss 2.9439688
step 30500: train loss 2.9268918, val loss 2.9612598
step 31000: train loss 2.91975, val loss 2.94916
step 31500: train loss 2.9251237, val loss 2.934956
step 32000: train loss 2.9079664, val loss 2.9431274
step 32500: train loss 2.910727, val loss 2.9221253
step 33000: train loss 2.91429, val loss 2.919466
step 33500: train loss 2.9074776, val loss 2.9280725
step 34000: train loss 2.896589, val loss 2.9004114
step 34500: train loss 2.898249, val loss 2.9251142
step 35000: train loss 2.8961527, val loss 2.9172351
step 35500: train loss 2.8839464, val loss 2.9006162
step 36000: train loss 2.8779233, val loss 2.9100876
step 36500: train loss 2.879361, val loss 2.9014406
step 37000: train loss 2.8839698, val loss 2.8981316
step 37500: train loss 2.853179, val loss 2.8779168
step 38000: train loss 2.8649895, val loss 2.894176
step 38500: train loss 2.8693879, val loss 2.8744462
step 39000: train loss 2.8525827, val loss 2.8651721
step 39500: train loss 2.858041, val loss 2.8500593
step 40000: train loss 2.8418512, val loss 2.863662
step 40500: train loss 2.8385842, val loss 2.8543704
step 41000: train loss 2.8311198, val loss 2.8574524
step 41500: train loss 2.825884, val loss 2.8499897
step 42000: train loss 2.8429782, val loss 2.8441114
step 42500: train loss 2.8070157, val loss 2.8388376
step 43000: train loss 2.8123505, val loss 2.842098
step 43500: train loss 2.810964, val loss 2.8345373
step 44000: train loss 2.8263602, val loss 2.830025
step 44500: train loss 2.811398, val loss 2.834848
step 45000: train loss 2.802633, val loss 2.810559
step 45500: train loss 2.8123126, val loss 2.8265247
step 46000: train loss 2.7979581, val loss 2.8048408
step 46500: train loss 2.7967849, val loss 2.8334157
step 47000: train loss 2.7803953, val loss 2.7922354
step 47500: train loss 2.7942781, val loss 2.825274
step 48000: train loss 2.7804523, val loss 2.792919
step 48500: train loss 2.785722, val loss 2.8042364
step 49000: train loss 2.7795138, val loss 2.7809396
step 49500: train loss 2.7776642, val loss 2.7782316
step 50000: train loss 2.769403, val loss 2.7787275
step 50500: train loss 2.7557025, val loss 2.7765558
step 51000: train loss 2.759183, val loss 2.7775955
step 51500: train loss 2.7498598, val loss 2.7687922
step 52000: train loss 2.764737, val loss 2.7726612
step 52500: train loss 2.7710688, val loss 2.7590082
step 53000: train loss 2.7473223, val loss 2.760512
step 53500: train loss 2.7373915, val loss 2.7564347
step 54000: train loss 2.7325678, val loss 2.7411654
step 54500: train loss 2.7540653, val loss 2.752793
step 55000: train loss 2.736955, val loss 2.751245
step 55500: train loss 2.7224433, val loss 2.7364216
step 56000: train loss 2.7233686, val loss 2.74944
step 56500: train loss 2.7202756, val loss 2.7465448
step 57000: train loss 2.7280054, val loss 2.7374096
step 57500: train loss 2.7064633, val loss 2.7330124
step 58000: train loss 2.6934423, val loss 2.7236161
step 58500: train loss 2.6968424, val loss 2.72582
step 59000: train loss 2.6981068, val loss 2.7159605
step 59500: train loss 2.695939, val loss 2.724237
step 60000: train loss 2.6998184, val loss 2.7238555
step 60500: train loss 2.6900072, val loss 2.7078435
step 61000: train loss 2.6998444, val loss 2.7143097
step 61500: train loss 2.6824317, val loss 2.699878
step 62000: train loss 2.678613, val loss 2.6927574
step 62500: train loss 2.695001, val loss 2.7028537
step 63000: train loss 2.6931143, val loss 2.6938097
step 63500: train loss 2.6818473, val loss 2.6830072
step 64000: train loss 2.6860394, val loss 2.6763582
step 64500: train loss 2.6754217, val loss 2.6692927
step 65000: train loss 2.652602, val loss 2.6785924
step 65500: train loss 2.6655686, val loss 2.6759882
step 66000: train loss 2.6485276, val loss 2.6638012
step 66500: train loss 2.6445954, val loss 2.6824584
step 67000: train loss 2.6588178, val loss 2.6743371
step 67500: train loss 2.665208, val loss 2.6798043
step 68000: train loss 2.6643429, val loss 2.6748931
step 68500: train loss 2.6562061, val loss 2.6429644
step 69000: train loss 2.6405647, val loss 2.648562
step 69500: train loss 2.6491652, val loss 2.6551437
step 70000: train loss 2.641609, val loss 2.6503496
step 70500: train loss 2.6256104, val loss 2.6489353
step 71000: train loss 2.6348572, val loss 2.6602316
step 71500: train loss 2.6440005, val loss 2.6452422
step 72000: train loss 2.625387, val loss 2.655331
step 72500: train loss 2.6233087, val loss 2.6433735
step 73000: train loss 2.623311, val loss 2.6347494
step 73500: train loss 2.609082, val loss 2.6489167
step 74000: train loss 2.6275, val loss 2.6279202
step 74500: train loss 2.624021, val loss 2.643931
step 74999: train loss 2.6234972, val loss 2.628585
step 75000: train loss 2.6114013, val loss 2.623228

davoclavo commented 1 year ago

Awesome! Thanks for the update, I'm glad you are making strides on the progress of this endeavor!

It's also very interesting that the loss doesn't match the Python implementation. One thing that might be worth testing is setting the same seed and making sure that all the implemented operations/methods that rely on random numbers actually use it. I would expect the same (and consistent) results on the same CPU architecture for both implementations, and across multiple runs of each.
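Once both implementations are seeded identically, a quick way to check whether they actually agree is to compare their logged losses numerically instead of by eye. A minimal sketch, assuming both programs print entries in the `step N: train loss X, val loss Y` format used in this thread; `LossPoint`, `parseLog` and `maxValGap` are hypothetical helper names, not part of storch.

```scala
import scala.util.matching.Regex

// One evaluation point from a training log.
final case class LossPoint(step: Int, train: Double, valLoss: Double)

// Matches "step 500: train loss 2.5476382, val loss 2.5570214" anywhere in the text,
// so it works on one-entry-per-line logs as well as on flattened ones.
val entry: Regex = raw"step (\d+): train loss ([0-9.]+), val loss ([0-9.]+)".r

def parseLog(text: String): Vector[LossPoint] =
  entry.findAllMatchIn(text).map { m =>
    LossPoint(m.group(1).toInt, m.group(2).toDouble, m.group(3).toDouble)
  }.toVector

// Step with the largest absolute validation-loss difference among steps both runs logged.
def maxValGap(a: Seq[LossPoint], b: Seq[LossPoint]): Option[(Int, Double)] =
  val bByStep = b.map(p => p.step -> p.valLoss).toMap
  val gaps = a.flatMap(p => bByStep.get(p.step).map(v => p.step -> math.abs(p.valLoss - v)))
  gaps.maxByOption(_._2)
```

With identical seeds and deterministic kernels the gap should stay near zero from step 0; a gap that is already large at the first evaluation points more towards a differing initialization or a wrapper bug than towards numerical drift.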

Perhaps other things also come into play to explain why the loss is not the same; I will keep thinking about it. One example that comes to mind: there was a bug in the Torch.randn wrapper a while back, so there could be other minor bugs in similar operations. If you are keen to share your implementation, I might be able to run it on my machine and look into why the results don't match the Python results.

hmf commented 1 year ago

@davoclavo Thanks for the feedback. Regarding the seed, I have tried to do as in the video, but this may not be correct. Also, some of the constants may be off. For the final version I will try to replicate this code, so general performance should match.

As for the causes of the differences in loss: when I tested the MNIST example on Linux, its behavior was not the same as on Mac. In fact, training did not converge. This was strange. @sbrunk changed the code, altering the learning rate so that learning would converge on both operating systems.

sbrunk commented 1 year ago

@hmf I'll try to look into this to get a better understanding. Is there anything I need to consider if I want to run it? I.e. I've seen you are using mill, right?

hmf commented 1 year ago

@sbrunk Thanks.

Is there anything I need to consider if I want to run it?

Not really. Simple Scala object. Messy code though. Sorry about that.

I've seen you are using mill, right?

Correct, but it just calls the main. Execution is in the object initialization. Something to correct.

I was hoping to contribute the Mill script (another issue). It is just missing project publishing. When I get time I want to upgrade it to the latest Laika version to avoid the need to override the Helium templates (currently it overrides the header).

hmf commented 1 year ago

I have implemented a clean version of this Python code. It is here. I am able to get a validation loss below 2.0 (even lower), as shown in the tutorial video, although it takes more iterations.

Unfortunately I am unable to use the exact same parameters due to memory issues. I am using a GPU with a whopping 24 GB. As soon as I start training, CUDA (nvidia-smi) shows over 18 GB being used. The (smaller) model (13_347_905 parameters) seems to use about 50 MiB. The training loop uses Pointer.physicalBytes(), which reports a stable 2.2 MiB after many iterations. So I am at a loss to explain why so much memory is used.
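The ~50 MiB figure for the weights can be checked with simple arithmetic, and the same arithmetic shows that weights plus optimizer state are nowhere near 18 GB. A rough sketch, assuming float32 parameters (4 bytes each) and an Adam-style optimizer that keeps one gradient buffer and two moment buffers per parameter; the optimizer assumption is ours, not taken from the code above.

```scala
// Rough lower bound on GPU memory held by parameters and optimizer state.
// Assumptions: float32 (4 bytes per value) and an Adam-style optimizer that keeps
// a gradient plus two moment buffers (first and second moment) per parameter.
val BytesPerFloat32 = 4L
val MiB = 1024.0 * 1024.0

def paramBytes(numParams: Long): Long = numParams * BytesPerFloat32

// weights + gradients + 2 Adam moment buffers = 4 copies of the parameter data
def trainingStateBytes(numParams: Long): Long = 4L * paramBytes(numParams)

def toMiB(bytes: Long): Double = bytes / MiB

@main def memoryEstimate(): Unit =
  val n = 13_347_905L
  println(f"weights: ${toMiB(paramBytes(n))}%.1f MiB")
  println(f"weights + grads + Adam state: ${toMiB(trainingStateBytes(n))}%.1f MiB")
```

For 13_347_905 parameters this gives about 50.9 MiB of weights and roughly 204 MiB including gradients and optimizer state. If the estimate holds, most of the observed usage would have to come from elsewhere: activations kept for backpropagation (which scale with batch size, context length and layer count), memory reserved by the caching allocator, or tensors that are never released.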

I have looked for the APIs but cannot find the calls to get the CUDA memory stats.

Can anyone give me some pointers on how to check where the memory is used and how to diagnose this issue?

TIA

sbrunk commented 1 year ago

@hmf Yeah, right now we don't really have a good way to do memory profiling. We need to look into that too. Perhaps we can use the JavaCPP Cuda bindings to get better GPU memory usage information.

One idea you could try for now is to run only parts of the model (e.g. just the attention layer) inside a training loop, using random inputs of the right size. That might help isolate which part consumes so much memory or where it leaks.

hmf commented 1 year ago

@sbrunk thanks for the suggestions. I have started using a kludge to try and get an idea where the memory is being allocated. What I do is set a Thread.sleep at certain parts of the code and use nvidia-smi to check the memory. Not practical, but it has helped me see that I need to add some additional PointerScopes. Still trying to figure out how the memory accumulates so much.
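Instead of pausing with Thread.sleep and reading nvidia-smi by hand, the same query can be run from inside the program at labelled points. A minimal sketch, assuming nvidia-smi is on the PATH; `GpuMem`, `parseGpuMem`, `queryGpuMem` and `logGpuMem` are hypothetical helpers. Note that nvidia-smi reports memory held by the process, which with PyTorch's caching allocator includes blocks that are reserved but no longer backing live tensors, so it can overstate what the model actually needs.

```scala
import scala.sys.process.*
import scala.util.Try

// Memory figures for one GPU, in MiB, as reported by nvidia-smi.
final case class GpuMem(usedMiB: Long, totalMiB: Long)

// Parses lines like "18432, 24576" produced by
//   nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits
def parseGpuMem(output: String): Vector[GpuMem] =
  output.linesIterator.map(_.trim).filter(_.nonEmpty).flatMap { line =>
    line.split(",").map(_.trim) match
      case Array(used, total) => Try(GpuMem(used.toLong, total.toLong)).toOption
      case _                  => None
  }.toVector

// Runs nvidia-smi once; returns an empty vector if the command is missing or fails.
def queryGpuMem(): Vector[GpuMem] =
  Try(Seq("nvidia-smi", "--query-gpu=memory.used,memory.total",
          "--format=csv,noheader,nounits").!!).toOption
    .map(parseGpuMem)
    .getOrElse(Vector.empty)

// Prints GPU memory at a labelled point in the training code.
def logGpuMem(label: String): Unit =
  queryGpuMem().zipWithIndex.foreach { case (m, i) =>
    println(s"[$label] GPU $i: ${m.usedMiB} / ${m.totalMiB} MiB")
  }
```

Calling, for example, `logGpuMem("after forward")` and `logGpuMem("after backward")` inside the loop gives a per-phase trace without stopping the program.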

Perhaps we can use the JavaCPP Cuda bindings to get better GPU memory usage information.

I was hoping the PointerScope and Pointer would help. No luck, but still looking into it. Do you by any chance have any examples of explicitly de-allocating Tensors via these interfaces/classes?

What JavaCPP Cuda bindings are you referring to? A quick look at the API does not reveal much. I also think that we would need to access this information in a device-independent manner. Maybe via Device?

EDIT: Device does not seem to be helpful. Need to look at torch.cuda.memory_stats for possible solution.

sbrunk commented 1 year ago

@sbrunk thanks for the suggestions. I have started using a kludge to try and get an idea where the memory is being allocated. What I do is set a Thread.sleep at certain parts of the code and use nvidia-smi to check the memory. Not practical, but it has helped me see that I need to add some additional PointerScopes. Still trying to figure out how the memory accumulates so much.

Does it allocate too much inside a single iteration already or does it grow over multiple iterations during the training loop?

I was hoping the PointerScope and Pointer would help. No luck, but still looking into it. Do you by any chance have any examples of explicitly de-allocating Tensors via these interfaces/classes?

In Storch itself, I think we only have the image classifier example and https://github.com/sbrunk/storch/issues/5, but you seem to be already doing it this way.

The JavaCPP tests for deallocation and PointerScope might be helpful here. There are also a few issues/discussions in the javacpp/javacpp-presets repo like this: https://github.com/bytedeco/javacpp-presets/discussions/1160

What JavaCPP Cuda bindings are you referring to? A quick look at the API does not reveal much. I also think that we would need to access this information in a device-independent manner. Maybe via Device?

The Java bindings to the CUDA toolkit itself. But that's a long shot, I'm not sure if it provides something usable for us here.

EDIT: Device does not seem to be helpful. Need to look at torch.cuda.memory_stats for possible solution.

It looks like LibTorch provides something like this, see: https://discuss.pytorch.org/t/libtorch-equivalent-of-torch-cuda-memory-reserved/165995 But I haven't found anything related in the JavaCPP PyTorch bindings, so it might need to be mapped in the preset. Perhaps you could open an issue in https://github.com/bytedeco/javacpp-presets to ask about it.

hmf commented 1 year ago

@sbrunk thanks for the suggestions. I have started using a kludge to try and get an idea where the memory is being allocated. What I do is set a Thread.sleep at certain parts of the code and use nvidia-smi to check the memory. Not practical, but it has helped me see that I need to add some additional PointerScopes. Still trying to figure out how the memory accumulates so much.

Does it allocate too much inside a single iteration already or does it grow over multiple iterations during the training loop?

It grows as it iterates.
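Since the memory grows over iterations rather than spiking once, one way to quantify the leak is to record (step, MiB) pairs, for example from nvidia-smi or from a memory figure printed by the training loop, and fit a straight line. A minimal sketch; `growthPerStep` is a hypothetical helper, and the samples would come from whatever memory probe is available.

```scala
// Least-squares slope of memory (MiB) against step number.
// A clearly positive slope over many samples suggests per-iteration accumulation (a leak);
// a slope near zero after warm-up suggests a stable working set.
def growthPerStep(samples: Seq[(Long, Double)]): Double =
  require(samples.size >= 2, "need at least two samples")
  val n = samples.size.toDouble
  val meanX = samples.map(_._1.toDouble).sum / n
  val meanY = samples.map(_._2).sum / n
  val cov = samples.map(p => (p._1 - meanX) * (p._2 - meanY)).sum
  val varX = samples.map(p => (p._1 - meanX) * (p._1 - meanX)).sum
  require(varX > 0, "samples must span more than one step")
  cov / varX
```

Multiplying the slope by the planned number of iterations gives a rough projection of how much memory the run would need before it finishes, which helps decide whether a leak is worth hunting or merely a slow drift.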

I was hoping the PointerScope and Pointer would help. No luck, but still looking into it. Do you by any chance have any examples of explicitly de-allocating Tensors via these interfaces/classes?

In Storch itself, I think we only have the image classifier example and #5, but you seem to be already doing it this way.

The JavaCPP tests for deallocation and PointerScope might be helpful here. There are also a few issues/discussions in the javacpp/javacpp-presets repo like this: bytedeco/javacpp-presets#1160

I had seen these already. What I have learned is that one of the functions (the one that calculates the validation and training loss) was accumulating memory. I added another PointerScope and now the model runs with the original parameters. The 13_443_137-parameter model still seems to consume too much memory (7 GiB). Not everyone has a GPU with that much memory. Maybe another PointerScope can reduce this, at a cost.
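The fix described here, giving the loss-estimation function its own scope, can be illustrated without native code. The following is a toy model, not JavaCPP: `NativeBuf` and `Scope` are hypothetical stand-ins for a native-backed tensor and for `org.bytedeco.javacpp.PointerScope`, which is `AutoCloseable` and deallocates the pointers attached to it when closed. The point the sketch makes is that the evaluation function returns a plain `Float`, so no native object has to outlive its scope.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Using

// Toy stand-ins (hypothetical): a "native" buffer and a scope that frees what it owns.
object Native:
  var live = 0 // number of buffers currently allocated

final class NativeBuf(val value: Float):
  Native.live += 1
  private var freed = false
  def free(): Unit = if !freed then { freed = true; Native.live -= 1 }

final class Scope extends AutoCloseable:
  private val owned = ArrayBuffer.empty[NativeBuf]
  def alloc(v: Float): NativeBuf = { val b = NativeBuf(v); owned += b; b }
  def close(): Unit = { owned.foreach(_.free()); owned.clear() }

// One scope per evaluation batch: intermediates are freed as soon as the batch is done,
// and only a primitive Float escapes the scope.
def estimateLoss(batches: Seq[Float]): Float =
  val losses = batches.map { x =>
    Using.resource(Scope()) { scope =>
      val logits = scope.alloc(x * 2f)           // intermediate "tensor"
      val loss   = scope.alloc(logits.value / 2f)
      loss.value                                 // copy the scalar out before the scope closes
    }
  }
  losses.sum / losses.size
```

In real storch code the same shape would wrap `new PointerScope()` instead of the toy `Scope`, with the caveat that any tensor returned from inside the block would be deallocated together with the scope unless it is explicitly kept alive.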

What JavaCPP Cuda bindings are you referring to? A quick look at the API does not reveal much. I also think that we would need to access this information in a device-independent manner. Maybe via Device?

The Java bindings to the CUDA toolkit itself. But that's a long shot, I'm not sure if it provides something usable for us here.

Ok. I agree with you.

EDIT: Device does not seem to be helpful. Need to look at torch.cuda.memory_stats for possible solution.

It looks like LibTorch provides something like this, see: https://discuss.pytorch.org/t/libtorch-equivalent-of-torch-cuda-memory-reserved/165995 But I haven't found anything related in the JavaCPP PyTorch bindings, so it might need to be mapped in the preset. Perhaps you could open an issue in https://github.com/bytedeco/javacpp-presets to ask about it.

Ok.

hmf commented 1 year ago

The current V2 implementation is using the same parameters but will not converge using the original learning rate. The only thing that is missing is the weight and bias initialization. The nn.Module does not seem to have an apply method like PyTorch's that does this.

So what is the best way forward here? Should we include such a method? What should we name it? Do we simply iterate through all layers and apply a function as Python does?
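A recursive apply can be sketched as: visit each child first, then the module itself, which is the order PyTorch's `nn.Module.apply` uses. The sketch below uses a hypothetical `Mod` trait rather than storch's actual `nn.Module`, and its `initWeights` only records the settings instead of filling tensors. The N(0, 0.02) weights and zero biases mirror the `_init_weights` function in the tutorial's repository.

```scala
// Hypothetical module tree, standing in for storch's nn.Module (not its actual API).
trait Mod:
  def children: Seq[Mod] = Seq.empty
  // Like PyTorch's Module.apply: visit every submodule recursively, then this module.
  final def applyRec(fn: Mod => Unit): this.type =
    children.foreach(_.applyRec(fn))
    fn(this)
    this

final class Linear(val in: Int, val out: Int) extends Mod:
  var weightStd = Double.NaN
  var biasZeroed = false

final class Embedding(val num: Int, val dim: Int) extends Mod:
  var weightStd = Double.NaN

final class Sequential(mods: Mod*) extends Mod:
  override def children: Seq[Mod] = mods

// Initializer in the style of the tutorial: N(0, 0.02) weights, zero biases.
// Here it only records what it would do; a real version would fill the weight tensors.
def initWeights(m: Mod): Unit = m match
  case l: Linear =>
    l.weightStd = 0.02
    l.biasZeroed = true
  case e: Embedding =>
    e.weightStd = 0.02
  case _ => ()
```

Usage would be `model.applyRec(initWeights)`. Keeping the traversal in the base trait and the type dispatch in the user-supplied function mirrors the Python design, so new layer types need no changes to the traversal itself.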

Should I open a new issue to discuss this?

TIA

EDIT: above it should read "The current V2 implementation is using the same parameters but will not converge as quickly using the original learning rate."

sbrunk commented 1 year ago

Good idea. A recursive apply like in Python should be quite useful. And yes, please create a new issue for this.

hmf commented 1 year ago

The results of V2 are on par with the tutorial (loss below 2.0), but convergence is slower. After a while it diverges. At the end I show an example of its output.

13443137 parameters
learningRate = 1.0E-4
maxIterations = 67000
dropout = 0.2
GPU total = 24.0 GiB
GPU used = 6.9 GiB
13443137 parameters >= 53772548 bytes = 51.3 MiB
step 0: train loss 4.335848, val loss 4.332262, mem 714.0 MiB @ 00 00:00:00.000, mean 00 00:00:00.000
step 500: train loss 2.5476382, val loss 2.5570214, mem 838.6 MiB @ 00 00:01:01.811, mean 00 00:00:00.123
step 1000: train loss 2.508413, val loss 2.5130005, mem 839.4 MiB @ 00 00:02:03.242, mean 00 00:00:00.122
step 1500: train loss 2.4970562, val loss 2.495719, mem 839.5 MiB @ 00 00:03:04.595, mean 00 00:00:00.122
step 2000: train loss 2.469262, val loss 2.4932768, mem 839.7 MiB @ 00 00:04:05.924, mean 00 00:00:00.122
step 2500: train loss 2.4545126, val loss 2.4732363, mem 839.7 MiB @ 00 00:05:07.223, mean 00 00:00:00.122
step 3000: train loss 2.4397492, val loss 2.4652252, mem 841.5 MiB @ 00 00:06:08.506, mean 00 00:00:00.122
step 3500: train loss 2.4432235, val loss 2.4631853, mem 841.6 MiB @ 00 00:07:09.783, mean 00 00:00:00.122
step 4000: train loss 2.4328675, val loss 2.457826, mem 841.6 MiB @ 00 00:08:11.065, mean 00 00:00:00.122
step 4500: train loss 2.4265091, val loss 2.4551604, mem 841.6 MiB @ 00 00:09:12.319, mean 00 00:00:00.122
step 5000: train loss 2.4185965, val loss 2.450682, mem 841.6 MiB @ 00 00:10:13.570, mean 00 00:00:00.122
step 5500: train loss 2.4081905, val loss 2.447644, mem 841.9 MiB @ 00 00:11:14.834, mean 00 00:00:00.122
step 6000: train loss 2.3958697, val loss 2.4314053, mem 842.1 MiB @ 00 00:12:16.084, mean 00 00:00:00.122
step 6500: train loss 2.381643, val loss 2.4313533, mem 842.1 MiB @ 00 00:13:17.338, mean 00 00:00:00.122
step 7000: train loss 2.364158, val loss 2.4134161, mem 842.1 MiB @ 00 00:14:18.601, mean 00 00:00:00.122
step 7500: train loss 2.3529167, val loss 2.4074175, mem 842.2 MiB @ 00 00:15:19.855, mean 00 00:00:00.122
step 8000: train loss 2.328031, val loss 2.3847246, mem 842.5 MiB @ 00 00:16:21.105, mean 00 00:00:00.122
step 8500: train loss 2.292856, val loss 2.351461, mem 842.5 MiB @ 00 00:17:22.352, mean 00 00:00:00.122
step 9000: train loss 2.2544227, val loss 2.321474, mem 842.5 MiB @ 00 00:18:23.577, mean 00 00:00:00.122
step 9500: train loss 2.219748, val loss 2.2897422, mem 842.5 MiB @ 00 00:19:24.806, mean 00 00:00:00.122
step 10000: train loss 2.1745658, val loss 2.2487366, mem 842.5 MiB @ 00 00:20:25.992, mean 00 00:00:00.122
step 10500: train loss 2.1545537, val loss 2.235534, mem 842.5 MiB @ 00 00:21:27.209, mean 00 00:00:00.122
step 11000: train loss 2.13079, val loss 2.2194557, mem 842.5 MiB @ 00 00:22:28.427, mean 00 00:00:00.122
step 11500: train loss 2.107516, val loss 2.1982605, mem 842.5 MiB @ 00 00:23:29.635, mean 00 00:00:00.122
step 12000: train loss 2.085714, val loss 2.1769443, mem 842.5 MiB @ 00 00:24:30.833, mean 00 00:00:00.122
step 12500: train loss 2.0651603, val loss 2.1682646, mem 842.5 MiB @ 00 00:25:32.036, mean 00 00:00:00.122
step 13000: train loss 2.04403, val loss 2.145483, mem 842.6 MiB @ 00 00:26:33.255, mean 00 00:00:00.122
step 13500: train loss 2.0215368, val loss 2.1287656, mem 842.8 MiB @ 00 00:27:34.475, mean 00 00:00:00.122
step 14000: train loss 2.073082, val loss 2.1562624, mem 842.8 MiB @ 00 00:28:35.691, mean 00 00:00:00.122
step 14500: train loss 2.0549197, val loss 2.1463277, mem 842.8 MiB @ 00 00:29:36.914, mean 00 00:00:00.122
step 15000: train loss 2.0292356, val loss 2.1369631, mem 842.8 MiB @ 00 00:30:38.123, mean 00 00:00:00.122
step 15500: train loss 2.0073128, val loss 2.1167858, mem 842.8 MiB @ 00 00:31:39.302, mean 00 00:00:00.122
step 16000: train loss 1.987694, val loss 2.1022239, mem 842.8 MiB @ 00 00:32:40.479, mean 00 00:00:00.122
step 16500: train loss 1.9841061, val loss 2.0968378, mem 842.8 MiB @ 00 00:33:41.664, mean 00 00:00:00.122
step 17000: train loss 1.964174, val loss 2.0827105, mem 842.8 MiB @ 00 00:34:42.826, mean 00 00:00:00.122
step 17500: train loss 1.9512877, val loss 2.0708296, mem 843.1 MiB @ 00 00:35:43.998, mean 00 00:00:00.122
step 18000: train loss 1.9287692, val loss 2.0533903, mem 843.3 MiB @ 00 00:36:45.177, mean 00 00:00:00.122
step 18500: train loss 1.9105072, val loss 2.0451093, mem 843.3 MiB @ 00 00:37:46.354, mean 00 00:00:00.122
step 19000: train loss 1.8970441, val loss 2.0320392, mem 843.3 MiB @ 00 00:38:47.511, mean 00 00:00:00.122
step 19500: train loss 1.8854179, val loss 2.0191305, mem 843.6 MiB @ 00 00:39:48.674, mean 00 00:00:00.122
step 20000: train loss 1.8791237, val loss 2.0184035, mem 843.6 MiB @ 00 00:40:49.850, mean 00 00:00:00.122
step 20500: train loss 1.8812588, val loss 2.0221977, mem 843.6 MiB @ 00 00:41:51.024, mean 00 00:00:00.122
step 21000: train loss 1.903744, val loss 2.030137, mem 843.6 MiB @ 00 00:42:52.190, mean 00 00:00:00.122
step 21500: train loss 1.8765324, val loss 2.0156875, mem 843.7 MiB @ 00 00:43:53.337, mean 00 00:00:00.122
step 22000: train loss 1.8608563, val loss 2.0002515, mem 843.7 MiB @ 00 00:44:54.489, mean 00 00:00:00.122
step 22500: train loss 1.8509007, val loss 1.9926138, mem 843.7 MiB @ 00 00:45:55.633, mean 00 00:00:00.122
step 23000: train loss 1.8334926, val loss 1.9830146, mem 843.7 MiB @ 00 00:46:56.772, mean 00 00:00:00.122
step 23500: train loss 1.8383644, val loss 1.9679743, mem 843.7 MiB @ 00 00:47:57.908, mean 00 00:00:00.122
step 24000: train loss 1.8247951, val loss 1.9709209, mem 843.7 MiB @ 00 00:48:59.069, mean 00 00:00:00.122
step 24500: train loss 1.8079953, val loss 1.9558063, mem 843.8 MiB @ 00 00:50:00.229, mean 00 00:00:00.122
step 25000: train loss 1.8013923, val loss 1.9549898, mem 843.8 MiB @ 00 00:51:01.393, mean 00 00:00:00.122
step 25500: train loss 1.7907901, val loss 1.945321, mem 843.8 MiB @ 00 00:52:02.553, mean 00 00:00:00.122
step 26000: train loss 1.7799153, val loss 1.939117, mem 843.8 MiB @ 00 00:53:03.719, mean 00 00:00:00.122
step 26500: train loss 1.7685446, val loss 1.9267551, mem 843.8 MiB @ 00 00:54:04.868, mean 00 00:00:00.122
step 27000: train loss 1.7624547, val loss 1.9231879, mem 843.8 MiB @ 00 00:55:06.017, mean 00 00:00:00.122
step 27500: train loss 1.7491188, val loss 1.9149151, mem 844.1 MiB @ 00 00:56:07.160, mean 00 00:00:00.122
step 28000: train loss 1.7429442, val loss 1.9086628, mem 844.1 MiB @ 00 00:57:08.307, mean 00 00:00:00.122
step 28500: train loss 1.7355636, val loss 1.9029418, mem 844.1 MiB @ 00 00:58:09.454, mean 00 00:00:00.122
step 29000: train loss 1.7248852, val loss 1.8960071, mem 844.1 MiB @ 00 00:59:10.608, mean 00 00:00:00.122
step 29500: train loss 1.7195947, val loss 1.8904512, mem 844.1 MiB @ 00 01:00:11.752, mean 00 00:00:00.122
step 30000: train loss 1.7153524, val loss 1.8848493, mem 844.1 MiB @ 00 01:01:12.904, mean 00 00:00:00.122
step 30500: train loss 1.7048767, val loss 1.8789163, mem 844.1 MiB @ 00 01:02:14.060, mean 00 00:00:00.122
step 31000: train loss 1.694385, val loss 1.870024, mem 844.1 MiB @ 00 01:03:15.210, mean 00 00:00:00.122
step 31500: train loss 1.6884319, val loss 1.8608154, mem 844.1 MiB @ 00 01:04:16.349, mean 00 00:00:00.122
step 32000: train loss 1.6768422, val loss 1.8586318, mem 844.1 MiB @ 00 01:05:17.483, mean 00 00:00:00.122
step 32500: train loss 1.6761434, val loss 1.8587543, mem 844.1 MiB @ 00 01:06:18.619, mean 00 00:00:00.122
step 33000: train loss 1.6758552, val loss 1.8544992, mem 844.1 MiB @ 00 01:07:19.746, mean 00 00:00:00.122
step 33500: train loss 1.67037, val loss 1.8574976, mem 844.1 MiB @ 00 01:08:20.870, mean 00 00:00:00.122
step 34000: train loss 1.6646343, val loss 1.8511721, mem 844.1 MiB @ 00 01:09:21.990, mean 00 00:00:00.122
step 34500: train loss 1.6610796, val loss 1.8486292, mem 844.3 MiB @ 00 01:10:23.103, mean 00 00:00:00.122
step 35000: train loss 1.6537488, val loss 1.8431506, mem 844.3 MiB @ 00 01:11:24.236, mean 00 00:00:00.122
step 35500: train loss 1.6544412, val loss 1.843468, mem 844.3 MiB @ 00 01:12:25.375, mean 00 00:00:00.122
step 36000: train loss 1.6563864, val loss 1.842051, mem 844.3 MiB @ 00 01:13:26.514, mean 00 00:00:00.122
step 36500: train loss 1.6723832, val loss 1.8542444, mem 844.3 MiB @ 00 01:14:27.646, mean 00 00:00:00.122
step 37000: train loss 1.6729113, val loss 1.8599828, mem 844.6 MiB @ 00 01:15:28.785, mean 00 00:00:00.122
step 37500: train loss 1.657896, val loss 1.8432986, mem 844.6 MiB @ 00 01:16:29.928, mean 00 00:00:00.122
step 38000: train loss 1.6419864, val loss 1.8300749, mem 845.2 MiB @ 00 01:17:31.076, mean 00 00:00:00.122
step 38500: train loss 1.6395802, val loss 1.831336, mem 845.4 MiB @ 00 01:18:32.209, mean 00 00:00:00.122
step 39000: train loss 1.6333517, val loss 1.8239709, mem 845.4 MiB @ 00 01:19:33.320, mean 00 00:00:00.122
step 39500: train loss 1.6248128, val loss 1.8159283, mem 845.4 MiB @ 00 01:20:34.444, mean 00 00:00:00.122
step 40000: train loss 1.6188323, val loss 1.8165076, mem 845.4 MiB @ 00 01:21:35.571, mean 00 00:00:00.122
step 40500: train loss 1.6140128, val loss 1.8128036, mem 845.4 MiB @ 00 01:22:36.716, mean 00 00:00:00.122
step 41000: train loss 1.6085365, val loss 1.8036649, mem 850.8 MiB @ 00 01:23:37.860, mean 00 00:00:00.122
step 41500: train loss 1.6002386, val loss 1.8010824, mem 850.8 MiB @ 00 01:24:38.994, mean 00 00:00:00.122
step 42000: train loss 1.5997845, val loss 1.803601, mem 850.8 MiB @ 00 01:25:40.133, mean 00 00:00:00.122
step 42500: train loss 1.9896176, val loss 2.08117, mem 850.8 MiB @ 00 01:26:41.304, mean 00 00:00:00.122
step 43000: train loss 2.6600757, val loss 2.7321503, mem 850.8 MiB @ 00 01:27:42.501, mean 00 00:00:00.122
step 43500: train loss 3.4551952, val loss 3.4863732, mem 850.8 MiB @ 00 01:28:43.643, mean 00 00:00:00.122
step 44000: train loss 3.44979, val loss 3.486229, mem 850.8 MiB @ 00 01:29:44.789, mean 00 00:00:00.122
step 44500: train loss 3.426718, val loss 3.467198, mem 850.8 MiB @ 00 01:30:45.911, mean 00 00:00:00.122
step 45000: train loss 3.4057078, val loss 3.4453905, mem 851.3 MiB @ 00 01:31:46.969, mean 00 00:00:00.122
step 45500: train loss 3.3713775, val loss 3.416267, mem 851.3 MiB @ 00 01:32:47.956, mean 00 00:00:00.121
step 46000: train loss 3.386201, val loss 3.4285066, mem 851.3 MiB @ 00 01:33:48.900, mean 00 00:00:00.121
step 46500: train loss 3.368285, val loss 3.4076788, mem 851.3 MiB @ 00 01:34:49.797, mean 00 00:00:00.121

Output (leading whitespace removed; there was too much of it):

GLOUCESTER:
Meaved:
Le, loopeak thather miscapio, and Caps noues?
Buch willift my freat I r'd did

LANDWAGLOUD INA:
Foolvest Hapry, shall he no cacinf it pake thou.
Are withe ret withsks shat usant:
Why somein so, helly how.
G Edwars haven, af of my unt: and,
To douer naty and babt! age she woun wilcy stith fabasten ther na s,
To beethe n as ave thigh bid sty are in in is
To nevirly thed, and our aft shal well, couser's be thringe.

ROMET:
For that ther sthing the meet is than him!
's.

HMOPSAMP'

sbrunk commented 1 year ago

@hmf amazing work!

I got my hands on a larger GPU now and will start playing with it.

dejvid commented 12 months ago

Is it possible to get the code for this example? It would be an excellent example project to learn from if it were published.

hmf commented 12 months ago

@dejvid The examples are in this fork. It is not complete yet; it still needs weight initialization (see issue #61). I have had trouble finishing due to an update to PyTorch (see issue #62).

hmf commented 11 months ago

@sbrunk I have created a clean branch with the changes that implement the example. This holds changes for #51 and #61 (apologies for the comment error in the commit). Some notes on the implementation:

  1. It does not match the performance of the original Python code. I have worked on this for a long time but cannot figure out what is wrong. It needs reviewing.
  2. The use of weight initialization makes performance significantly worse. I have left it in. Maybe it should be removed or corrected (it might not be the intended/best initialization).

I now have another issue. After updating to your latest changes (2.1.1), the GPU does not work on my side. This is also true for the original LeNet example (nvidia-smi shows no process using the GPU). Note that because of this the current code breaks; it is not usable without a GPU.

EDIT: I forgot to mention that with version 2.1.0 I had some memory issues that I did not have with the previous version. In particular, I implemented a wrapper for memory_stats. I think it is necessary to add the memory management functions (such as clearing the cache) to allow us to use storch effectively.

Could you check this and give me feedback?

TIA

sbrunk commented 11 months ago

Thanks @hmf for pushing this forward. Could you create a PR from your branch? That should make it easier to do reviewing.

I'll try to figure out the new GPU issue.

hmf commented 11 months ago

@sbrunk The code assumes a GPU is available. The printMemoryInfo will fail otherwise. Remove it if you need to test with CPU only.

The class is gpt.V2.

sbrunk commented 10 months ago

I now have another issue. After updating to your latest changes (2.1.1), the GPU does not work on my side. This is also true for the original LeNet example (nvidia-smi shows no process using the GPU). Note that because of this the current code breaks; it is not usable without a GPU.

I can't reproduce the GPU issue at the moment, but I could only try on an RTX 4090 so far, which is ADA architecture, while 3090 is Ampere.

It did work with 2.1.0 for you, right?

Could you give it a try with the latest update to PyTorch 2.1.2 by bumping the PyTorch patch version in build.sbt?

- val pytorchVersion = "2.1.1"
+ val pytorchVersion = "2.1.2"

hmf commented 10 months ago

@sbrunk

GPU working with 2.1.2

Note that with:

set ThisBuild / enableGPU := true

In sbt I now get:

[error] stack trace is suppressed; run last core / update for the full output
[error] (core / update) lmcoursier.internal.shaded.coursier.error.FetchError$DownloadingArtifacts: Error fetching artifacts:
[error] https://oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.1.1-1.5.10-SNAPSHOT/pytorch-2.1.1-1.5.10-20231204.171720-12-linux-x86_64-gpu.jar: not found: https://oss.sonatype.org/content/repositories/snapshots/org/bytedeco/pytorch/2.1.1-1.5.10-SNAPSHOT/pytorch-2.1.1-1.5.10-20231204.171720-12-linux-x86_64-gpu.jar

But no issues with 2.1.2.

Thanks.

EDIT: yes it is working with 2.1.0