AccelerateHS / accelerate-tensorflow

BSD 3-Clause "New" or "Revised" License

Add `tabulate` function #18

Open bwijgers opened 2 years ago

bwijgers commented 2 years ago

This corresponds to ZipCat in ConCat, which should correspond with generate in Accelerate.

bwijgers commented 2 years ago

Implementing `tabulate` should indeed be possible with `generate`. In `generate sh f`, `f` is applied to every (multi-dimensional) index within the shape `sh`, so it should be simple to compute the value for each index.
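In Accelerate, `generate :: (Shape sh, Elt a) => Exp sh -> (Exp sh -> Exp a) -> Acc (Array sh a)` takes the extent of the array (for a matrix, `Z :. rows :. cols`) and applies the function at every index of that extent; `tabulate` is the same idea with the index type given by `Rep f`. The sketch below models that behaviour with plain nested lists, assuming a 2-D shape. `generate2` and `rowMajorPositions` are hypothetical names for illustration, not part of the Accelerate API.

```haskell
-- A 2-D index (row, column). This list-based model is hypothetical and
-- only mirrors the semantics of Accelerate's `generate`.
type Ix2 = (Int, Int)

-- | Build a rows x cols "array" (row-major nested lists) by applying `f`
-- to every index; `f` never sees the shape itself, only one index.
generate2 :: Ix2 -> (Ix2 -> a) -> [[a]]
generate2 (rows, cols) f =
  [ [ f (r, c) | c <- [0 .. cols - 1] ] | r <- [0 .. rows - 1] ]

-- | Example body: each element's value is computed from its own index.
rowMajorPositions :: Ix2 -> [[Int]]
rowMajorPositions sh@(_, cols) = generate2 sh (\(r, c) -> r * cols + c)
```

For example, `rowMajorPositions (2, 3)` evaluates to `[[0,1,2],[3,4,5]]`.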

bwijgers commented 2 years ago

I didn't realize that the idea was that we provide these functions before closing this issue; I'll reopen this and implement such a function.

bwijgers commented 2 years ago

`generate` is currently not supported; we need to fix that.

tmcdonell commented 2 years ago

Supporting a fully general `generate` is difficult: it corresponds to the full lifting transformation, which I started in accelerate-tensorflow/icebox/Vectorise.hs but which is far from complete. If we can get some examples of how it is used, though, we might be able to support the specific functionality you need without implementing the nuclear option. Filling an array with a constant value, for instance, is also a `generate`, but we can look for this pattern and implement it fairly easily. @mikesperber
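The pattern-matching approach can be sketched on a toy expression language for the body of a 2-D `generate` over indices `(i, j)`. Everything here (`Expr`, `Lowered`, `lower`, `runLowered`) is hypothetical and does not reflect accelerate-tensorflow's actual internals; it only shows how recognising a few special cases (a constant fill, and a diagonal such as an identity matrix) avoids the general lifting transformation, with anything unrecognised reported as unsupported.

```haskell
-- Hypothetical AST for the body of `generate` over a 2-D index (i, j).
data Expr
  = Lit Double        -- a constant value
  | IfEqIJ Expr Expr  -- if i == j then a else b
  | Other             -- anything we cannot classify

-- Hypothetical specialised operations a backend could emit instead.
data Lowered
  = Fill Double            -- fill the whole array with one constant
  | Diagonal Double Double -- one value on the diagonal, another elsewhere
  | Unsupported
  deriving (Eq, Show)

-- | Try to map a `generate` body onto a specialised operation.
lower :: Expr -> Lowered
lower (Lit c) = Fill c
lower (IfEqIJ (Lit a) (Lit b))
  | a == b    = Fill a       -- both branches equal: still a plain fill
  | otherwise = Diagonal a b
lower _ = Unsupported

-- | Reference semantics of the lowered operations on an n x n array.
runLowered :: Int -> Lowered -> Maybe [[Double]]
runLowered n (Fill c) = Just (replicate n (replicate n c))
runLowered n (Diagonal a b) =
  Just [ [ if i == j then a else b | j <- [0 .. n - 1] ] | i <- [0 .. n - 1] ]
runLowered _ Unsupported = Nothing
```

The design choice is to fail explicitly (`Unsupported`) rather than fall back silently, so each new use case shows up as a concrete pattern to add.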

mikesperber commented 2 years ago

I think at the top you mean RepresentableCat, which has a tabulateC method, right?

mikesperber commented 2 years ago

Here are the uses we have for it - note that the nested functors will eventually end up as 2-dimensional matrices:

-- | zero-functor with a `one` at index `i`.
oneHot ::
  ( Rep.Representable f,
    Num a,
    Eq (Rep.Rep f)
  ) =>
  Rep.Rep f ->
  f a
oneHot i = Rep.tabulate (\i' -> if i == i' then 1 else 0)

-- | Generalized identity matrix.
idM ::
  ( Rep.Representable f,
    Num a,
    Eq (Rep.Rep f)
  ) =>
  f (f a)
idM = Rep.tabulate oneHot
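To see concretely what `oneHot` and `idM` produce, the following self-contained sketch replaces `Data.Functor.Rep` from the adjunctions package with a minimal local `Representable` class (only `tabulate` and `index`), and uses a hypothetical three-element functor `V3` indexed by a hypothetical type `I3`. `idM` then yields the 3x3 identity matrix; flattened to a 2-D array, that is a `generate` whose body is `if i == j then 1 else 0`.

```haskell
{-# LANGUAGE TypeFamilies, FlexibleContexts #-}

-- Minimal local stand-in for Data.Functor.Rep.Representable (adjunctions),
-- so the example needs no extra packages.
class Functor f => Representable f where
  type Rep f
  tabulate :: (Rep f -> a) -> f a
  index    :: f a -> Rep f -> a

-- Hypothetical functor with exactly three positions.
data V3 a = V3 a a a deriving (Eq, Show)

instance Functor V3 where
  fmap g (V3 x y z) = V3 (g x) (g y) (g z)

-- Hypothetical index type for V3.
data I3 = I0 | I1 | I2 deriving (Eq, Show)

instance Representable V3 where
  type Rep V3 = I3
  tabulate g = V3 (g I0) (g I1) (g I2)
  index (V3 x _ _) I0 = x
  index (V3 _ y _) I1 = y
  index (V3 _ _ z) I2 = z

-- | Same definitions as in the comment, against the local class.
oneHot :: (Representable f, Num a, Eq (Rep f)) => Rep f -> f a
oneHot i = tabulate (\i' -> if i == i' then 1 else 0)

idM :: (Representable f, Num a, Eq (Rep f)) => f (f a)
idM = tabulate oneHot
```

For example, `idM :: V3 (V3 Int)` evaluates to `V3 (V3 1 0 0) (V3 0 1 0) (V3 0 0 1)`.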

And here's the other place:

-- | Create a functor of all kernel positions of a convolution.
makeKernelPosFunctor ::
  forall
    dilationRate
    kernelPosFunctor
    windowFunctor
    kernelFunctor.
  ( KnownNat dilationRate,
    Integral (Rep.Rep windowFunctor),
    Integral (Rep.Rep kernelPosFunctor),
    Integral (Rep.Rep kernelFunctor),
    Pointed kernelPosFunctor,
    Zip.Zip kernelPosFunctor,
    Rep.Representable windowFunctor,
    Rep.Representable kernelFunctor,
    Rep.Representable kernelPosFunctor,
    Functor kernelPosFunctor
  ) =>
  kernelPosFunctor (kernelFunctor (Rep.Rep windowFunctor))
makeKernelPosFunctor =
  Rep.tabulate (makeIndexFunctor @windowFunctor @kernelPosFunctor dilationRate)
  where
    dilationRate = toInteger (natVal (Proxy @dilationRate))

-- | For an index (`idx`) of an output functor of a convolution, create the
-- corresponding kernel position, i.e. a `kernelFunctor` of indices of an input
-- functor.
-- Maps every index `x` of `kernelFunctor` to `offSet + dilationRate * x`,
-- where `offSet` (i.e. the position of the kernel's first entry) is `idx`.
makeIndexFunctor ::
  ( Rep.Representable windowInFunctor,
    Rep.Representable windowOutFunctor,
    Rep.Representable kernelFunctor,
    Integral (Rep.Rep windowInFunctor),
    Integral (Rep.Rep windowOutFunctor),
    Integral (Rep.Rep kernelFunctor),
    Integral dilationRate
  ) =>
  dilationRate ->
  Rep.Rep windowOutFunctor ->
  kernelFunctor (Rep.Rep windowInFunctor)
makeIndexFunctor dilationRate idx = Rep.tabulate tabulator
  where
    offSet = fromIntegral idx
    tabulator x =
      fromIntegral (shiftIndex (fromIntegral dilationRate) offSet (fromIntegral x))

-- | Transform an index (`idx`) of a `kernelFunctor` to the corresponding index of
-- an input functor for a given kernel position (specified by `offSet`).
shiftIndex :: Integer -> Integer -> Integer -> Integer
shiftIndex dilationRate offSet idx = offSet + dilationRate * idx
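Flattened into a 2-D array indexed by (output position `o`, kernel offset `x`), `makeKernelPosFunctor` is a `generate` whose body is the affine function `o + dilationRate * x`, so, like the constant fill mentioned earlier in the thread, it might be recognisable as a specific pattern. The list-based sketch below assumes 1-D windows; `kernelPositions`, `outSize`, and `kernelSize` are hypothetical names standing in for `makeKernelPosFunctor` and the sizes of `kernelPosFunctor` and `kernelFunctor`.

```haskell
-- Same definition as `shiftIndex` above, repeated so this block compiles alone.
shiftIndex :: Integer -> Integer -> Integer -> Integer
shiftIndex dilationRate offSet idx = offSet + dilationRate * idx

-- | Row `o` lists the input indices read by the kernel at output index `o`.
-- Hypothetical 1-D model of `makeKernelPosFunctor`.
kernelPositions :: Integer -> Integer -> Integer -> [[Integer]]
kernelPositions dilationRate outSize kernelSize =
  [ [ shiftIndex dilationRate o x | x <- [0 .. kernelSize - 1] ]
  | o <- [0 .. outSize - 1] ]
```

For example, with a dilation rate of 2, three output positions, and a kernel of size 3, `kernelPositions 2 3 3` evaluates to `[[0,2,4],[1,3,5],[2,4,6]]`.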

Does this help?

mikesperber commented 1 year ago

We're not sure we need this one, so it's OK to defer.