It would be helpful to have gates for trig functions with simple derivatives, such as `sin`, `cos`, etc. The general pattern can be seen in `grad.arithmetic_ops`. A Gate must have a `backward` method and a `cache` method. For example, the `sin` gate would look like:
```crystal
class Num::Grad::SinGate(T) < Num::Grad::Gate(T)
  getter a : Num::Grad::Variable(T)

  def initialize(@a : Num::Grad::Variable(T))
  end

  def backward(payload : Num::Grad::Payload(T)) : Array(T)
    gradient = payload.variable.grad

    # Chain rule: d/dx sin(x) = cos(x), so the upstream gradient is
    # multiplied elementwise by cos of the cached input values
    r0 = gradient.map(a.value) do |i, j|
      i * Math.cos(j)
    end
    [r0]
  end

  def cache(result : Num::Grad::Variable(T), *args)
    a = args[0]
    result.grad = T.zeros_like(result.value)
    result.requires_grad = true
    Num::Grad.register("Sin", self, result, a)
  end
end
```
After the gate is created, an operator should be added directly to `Variable`, calling this function and caching it on a context:
```crystal
class Num::Grad::Variable(T)
  def sin : Num::Grad::Variable(T)
    result = @context.variable(@value.sin)
    if self.is_grad_needed
      # The gate must cache the *input* variable, since backward
      # needs the original values to compute cos
      gate = Num::Grad::SinGate.new(self)
      gate.cache(result, self)
    end
    result
  end
end
```
Testing the derivative of the `sin` function:

```crystal
ctx = Num::Grad::Context(Tensor(Float64)).new
t = [0.0, Math::PI].to_tensor
a = ctx.variable(t)
b = a.sin
b.backprop
puts a.grad
# [1, -1]
```

The gradient is cos(x) evaluated elementwise: cos(0) = 1 and cos(π) = −1.
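A quick way to sanity-check any new gate is to compare its analytic gradient against a central-difference approximation; a minimal sketch using only the standard `Math` module:

```crystal
# Central-difference approximation of d/dx sin(x), to compare
# against the gate's analytic gradient cos(x)
def numeric_deriv(x : Float64, eps : Float64 = 1e-6) : Float64
  (Math.sin(x + eps) - Math.sin(x - eps)) / (2 * eps)
end

puts numeric_deriv(0.0)      # ~ 1.0
puts numeric_deriv(Math::PI) # ~ -1.0
```

The approximations should match the values in `a.grad` to within a few multiples of `eps`.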