Closed dahongdou111 closed 2 years ago
@dahongdou111
Yes, you can access the model's parameters. See: https://d2l.djl.ai/chapter_deep-learning-computation/parameters.html
https://d2l-zh.djl.ai/chapter_deep-learning-computation/parameters.html
@frankfliu Yes, I have seen that section. But I can only read the model's parameters; I can't set them more than once, because the Parameter class's setArray() method throws an exception once the shape has been set. So how can I set the parameters multiple times?
    public void setArray(NDArray array) {
        if (this.shape != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        } else {
            this.array = array;
            this.shape = array.getShape();
            array.setName(this.name);
        }
    }

    public void setShape(Shape shape) {
        if (this.array != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        } else {
            this.shape = shape;
        }
    }
@dahongdou111 Why do you need to set it multiple times?
If you want to modify the value, you can update the NDArray directly with NDArray.set(Buffer data), or copy the value from another NDArray with NDArray.copyTo(array).
@frankfliu Because DDPG uses two networks with the same structure. The target network has to be updated by hand, so I need to set the target network's parameters multiple times.
The parameter contains a mutable array, so you can use parameter.getArray().set(...) to apply your updates.
Yeah, I have used NDArray.copyTo(array) to apply my updates.
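For anyone landing here later, a minimal sketch of the soft update ta = (1 - TAU) * ta + TAU * ea. It is written against plain float[] buffers rather than the DJL API, so it only illustrates the arithmetic and the "update in place instead of calling setArray() again" idea. `SoftUpdateSketch` and `softUpdate` are hypothetical names, not part of DJL; in DJL the same in-place update would presumably be applied to the array returned by parameter.getArray() for each target parameter.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DJL: illustrates the DDPG soft update
// on plain float[] buffers standing in for parameter arrays.
final class SoftUpdateSketch {

    // Overwrites target in place with (1 - tau) * target + tau * eval.
    static void softUpdate(float[] target, float[] eval, float tau) {
        if (target.length != eval.length) {
            throw new IllegalArgumentException("target and eval must have the same length");
        }
        for (int i = 0; i < target.length; i++) {
            target[i] = (1.0f - tau) * target[i] + tau * eval[i];
        }
    }

    public static void main(String[] args) {
        float[] target = {1.0f, 2.0f};
        float[] eval = {3.0f, 6.0f};
        softUpdate(target, eval, 0.5f);
        System.out.println(Arrays.toString(target)); // prints [2.0, 4.0]
    }
}
```

The target buffer is mutated rather than replaced, which mirrors why the setArray() restriction does not get in the way: the parameter keeps the same array object and only its contents change.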
Now I want to implement a reinforcement learning method, DDPG. It has an eval network and a target network, which have the same structure. The eval network is updated through gradients, but the target network needs to be updated by hand (like this: ta = (1 - TAU) * ta + TAU * ea). How can I implement this with DJL?

    # networks parameters
    self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval')
    self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target')
    self.ce_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval')
    self.ct_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target')