crystal-data / num.cr

Scientific computing in pure Crystal
MIT License

Add grad gates for trigonometric functions #48

Closed: christopherzimmerman closed this issue 3 years ago

christopherzimmerman commented 4 years ago

It would be helpful to have gates for trigonometric functions with simple derivatives, such as sin and cos. The general pattern can be seen in grad.arithmetic_ops. A Gate must define two methods: backward, which returns the gradient with respect to each input, and cache, which prepares the result variable and registers the gate. For example, the sin gate would look like this:

class Num::Grad::SinGate(T) < Num::Grad::Gate(T)
  getter a : Num::Grad::Variable(T)

  def initialize(@a : Num::Grad::Variable(T))
  end

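  # Chain rule for b = sin(a): grad(a) = grad(b) * cos(a), elementwise
  def backward(payload : Num::Grad::Payload(T)) : Array(T)
    gradient = payload.variable.grad
    r0 = gradient.map(a.value) do |i, j|
      i * Math.cos(j)
    end
    # One gradient per input; sin has a single input
    [r0]
  end

  # Gives the result a zeroed gradient, marks it as requiring a gradient,
  # and registers this gate with its result and input for backprop
  def cache(result : Num::Grad::Variable(T), *args)
    a = args[0]
    result.grad = T.zeros_like(result.value)
    result.requires_grad = true
    Num::Grad.register("Sin", self, result, a)
  end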
end
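The backward method above is the chain rule: if b = sin(a), the gradient flowing into a is the upstream gradient of b multiplied elementwise by cos(a), which is what gradient.map(a.value) computes. The same one-line rule covers the other simple trig gates; the cos and tan rows below are the standard derivatives, listed for reference rather than taken from this issue (the symbol ⊙ denotes the elementwise product):

```latex
\begin{aligned}
b = \sin a &\;\Rightarrow\; \frac{\partial L}{\partial a} = \frac{\partial L}{\partial b} \odot \cos a \\
b = \cos a &\;\Rightarrow\; \frac{\partial L}{\partial a} = -\frac{\partial L}{\partial b} \odot \sin a \\
b = \tan a &\;\Rightarrow\; \frac{\partial L}{\partial a} = \frac{\partial L}{\partial b} \odot \left(1 + \tan^2 a\right)
\end{aligned}
```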

After the gate is created, a method should be added directly to Variable that computes the forward value and, when gradients are needed, creates the gate and caches it on the context:

class Num::Grad::Variable(T)
  def sin : Num::Grad::Variable(T)
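    # Forward pass: elementwise sin of the wrapped value
    result = @context.variable(@value.sin)

    if self.is_grad_needed
      # The gate must hold the input (self), not the result:
      # backward evaluates cos(a.value) on the input of sin
      gate = Num::Grad::SinGate.new(self)
      gate.cache(result, self)
    end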
    result
  end
end
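To see the pattern end to end without the num.cr shard, here is a minimal, stdlib-only scalar sketch. MiniVar, MiniTape, MiniGate, MiniSinGate and MiniCosGate are hypothetical names for illustration, not num.cr classes; a flat tape walked in reverse stands in for the context, gates have a single input, and each tape is assumed to be backpropagated only once. It also shows why the gate must capture the input variable: backward needs cos of the input.

```crystal
# Stdlib-only scalar sketch of the gate pattern.
# All names are hypothetical and are NOT part of num.cr.

abstract class MiniGate
  getter input : MiniVar
  getter result : MiniVar

  def initialize(@input : MiniVar, @result : MiniVar)
  end

  # Given d(loss)/d(result), return d(loss)/d(input)
  abstract def backward(upstream : Float64) : Float64
end

class MiniSinGate < MiniGate
  # d/dx sin(x) = cos(x), evaluated at the gate's input
  def backward(upstream : Float64) : Float64
    upstream * Math.cos(input.value)
  end
end

class MiniCosGate < MiniGate
  # d/dx cos(x) = -sin(x), evaluated at the gate's input
  def backward(upstream : Float64) : Float64
    -upstream * Math.sin(input.value)
  end
end

# Records gates in creation order (plays the role of the context)
class MiniTape
  getter gates = [] of MiniGate

  def register(gate : MiniGate) : Nil
    @gates << gate
  end
end

class MiniVar
  getter value : Float64
  property grad : Float64 = 0.0

  def initialize(@tape : MiniTape, @value : Float64)
  end

  def sin : MiniVar
    result = MiniVar.new(@tape, Math.sin(@value))
    # Capture the input (self), not the result
    @tape.register(MiniSinGate.new(self, result))
    result
  end

  def cos : MiniVar
    result = MiniVar.new(@tape, Math.cos(@value))
    @tape.register(MiniCosGate.new(self, result))
    result
  end

  # Seed d(self)/d(self) = 1, then apply each gate's chain rule
  # in reverse creation order
  def backprop : Nil
    @grad = 1.0
    @tape.gates.reverse_each do |gate|
      gate.input.grad += gate.backward(gate.result.grad)
    end
  end
end

# Usage: gradient of sin at 0 and at pi, one fresh tape per backprop
[0.0, Math::PI].each do |v|
  tape = MiniTape.new
  x = MiniVar.new(tape, v)
  x.sin.backprop
  puts x.grad # prints 1.0, then -1.0
end
```

Keeping the input on the gate (rather than recomputing it from the result) is what lets backward evaluate the local derivative at the right point; the num.cr version does the same through its a getter.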

Testing the derivative of the sin function:

ctx = Num::Grad::Context(Tensor(Float64)).new
t = [0.0, Math::PI].to_tensor
a = ctx.variable(t)
b = a.sin
b.backprop
puts a.grad
# [1, -1]  (the gradient of sin is cos: cos(0) = 1, cos(π) = -1)
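Since the derivative of sin is cos, the expected gradient at [0.0, Math::PI] is [cos(0), cos(π)], that is [1, -1]. A quick way to sanity-check any new trig gate is to compare it against a central finite difference. The sketch below does this in plain Crystal for sin; central_diff is a hypothetical helper, not part of num.cr, and the step h = 1e-6 is an arbitrary choice.

```crystal
# Central finite difference: (f(x + h) - f(x - h)) / 2h
# central_diff is a hypothetical helper, not part of num.cr.
def central_diff(f : Float64 -> Float64, x : Float64, h : Float64 = 1e-6) : Float64
  (f.call(x + h) - f.call(x - h)) / (2.0 * h)
end

sin_f = ->(v : Float64) { Math.sin(v) }

[0.0, Math::PI].each do |x|
  numeric = central_diff(sin_f, x)
  analytic = Math.cos(x) # the factor SinGate#backward multiplies by
  puts "x = #{x}: numeric #{numeric.round(6)}, analytic #{analytic}"
end
```

The same check works for cos or tan by swapping the proc and the analytic derivative.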
christopherzimmerman commented 3 years ago

Closed by #56