stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Add Conv and ConvTranspose #42

Closed dlwh closed 12 months ago

dlwh commented 1 year ago

Fixes #23

leizaf commented 11 months ago
class _ConvBase(eqx.Module):
    """
    Base class for Conv and ConvTranspose. Mostly just contains shared code.
    """

    Spatial: tuple[str | Axis, ...] = eqx.field(static=True)
    In: Axis = eqx.field(static=True)
    Out: Axis = eqx.field(static=True)
    weight: NamedArray = eqx.field(static=True)
    bias: Optional[NamedArray] = eqx.field(static=True)

How come weight and bias are static fields? Doesn't this prevent grad? It also prevents me from casting their type.
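For context on why a static field would block gradients: in Equinox, a field marked static is stored as part of the module's structure, not as one of its pytree leaves, and JAX transformations such as grad (as well as tree-map-based dtype casts) only operate on leaves. The block below is a minimal, library-free sketch of that partitioning idea. It is not Equinox's implementation; `static_field`, `split_leaves`, `BuggyConv`, and `FixedConv` are hypothetical names made up for illustration, and plain floats stand in for arrays.

```python
from dataclasses import dataclass, field, fields


def static_field():
    """Hypothetical stand-in for eqx.field(static=True): tags the field via metadata."""
    return field(metadata={"static": True})


def split_leaves(module):
    """Partition a dataclass instance into (leaves, static) dicts of name -> value.

    Only `leaves` would be visible to a gradient or tree-map style transformation.
    """
    leaves, static = {}, {}
    for f in fields(module):
        target = static if f.metadata.get("static", False) else leaves
        target[f.name] = getattr(module, f.name)
    return leaves, static


@dataclass(frozen=True)
class BuggyConv:
    # Mirrors the snippet quoted above: the parameter is marked static.
    in_size: int = static_field()
    weight: float = static_field()


@dataclass(frozen=True)
class FixedConv:
    # Assumed fix: metadata stays static, the parameter becomes a leaf.
    weight: float
    in_size: int = static_field()


buggy_leaves, buggy_static = split_leaves(BuggyConv(in_size=3, weight=0.5))
fixed_leaves, fixed_static = split_leaves(FixedConv(weight=0.5, in_size=3))

print(sorted(buggy_leaves))  # [] -> nothing to differentiate or cast
print(sorted(fixed_leaves))  # ['weight']
```

Under this model, the likely fix for the quoted class would be to drop static=True from weight and bias while leaving Spatial, In, and Out static, though the thread itself does not show the corrected code.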

dlwh commented 11 months ago

Ah crap. I wasn’t careful and got bitten by Copilot.

On Fri, Oct 13, 2023 at 5:42 PM Lei @.***> wrote:

class _ConvBase(eqx.Module):
    """
    Base class for Conv and ConvTranspose. Mostly just contains shared code.
    """

    Spatial: tuple[str | Axis, ...] = eqx.field(static=True)
    In: Axis = eqx.field(static=True)
    Out: Axis = eqx.field(static=True)
    weight: NamedArray = eqx.field(static=True)
    bias: Optional[NamedArray] = eqx.field(static=True)

How come weight and bias are static fields? Doesn't this prevent grad? It also prevents me from casting their type.
