apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Make a symbol's gradient one of the network's outputs #19650

Open · troyliu0105 opened this issue 3 years ago

troyliu0105 commented 3 years ago

Description

I want to build a GAM for a SymbolBlock, and I wonder whether there is an API for this, something like the following:

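```python
import mxnet as mx
from mxnet import autograd
from mxnet.gluon import SymbolBlock

# `net` is an existing Gluon network; `internal_name` is one of the names
# listed by out.get_internals().list_outputs().
ipt = mx.sym.var('data')
out = net(ipt)
internal = out.get_internals()[internal_name]
# Requested API: a symbol that evaluates to the gradient flowing back into `internal`.
grad_op = mx.sym.get_grad(internal)
net = SymbolBlock(mx.sym.Group([grad_op, out]), ipt, params=net.collect_params())
data = mx.nd.random.uniform(shape=(1, 3, 224, 224))
with autograd.record():
    output = net(data)
output[1].backward()

# ...after which the gradient of `internal` could be read from the first output
internal_grad = output[0]
```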

😃