deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.12k stars, 654 forks

How can I set a layer's parameters manually, like in TensorFlow? #1447

Closed: dahongdou111 closed this issue 2 years ago

dahongdou111 commented 2 years ago

Now I want to implement a reinforcement learning method, DDPG. It has an eval network and a target network with the same structure. The eval network is updated through gradient updates, but the target network has to be updated by hand (like this: ta = (1 - TAU) * ta + TAU * ea). How can I implement this in DJL?

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.get_collection, tf.assign)

# The lines below run inside a method of the DDPG class (hence `self`).
# networks parameters
self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval')
self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target')
self.ce_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval')
self.ct_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target')

TAU = 0.01
# target net replacement: blend each target variable toward its eval counterpart
self.soft_replace = [[tf.assign(ta, (1 - TAU) * ta + TAU * ea), tf.assign(tc, (1 - TAU) * tc + TAU * ec)]
                     for ta, ea, tc, ec in zip(self.at_params, self.ae_params, self.ct_params, self.ce_params)]
```
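For reference, the assignment in the snippet is the DDPG soft target update, applied to every pair of matching parameter tensors, where $\theta'$ is a target-network parameter, $\theta$ the corresponding eval-network parameter, and $\tau$ is `TAU`:

$$
\theta' \leftarrow (1 - \tau)\,\theta' + \tau\,\theta
\qquad\Longrightarrow\qquad
\theta'_{\text{new}} - \theta = (1 - \tau)\,(\theta' - \theta)
$$

So with $\tau = 0.01$, each update closes 1% of the remaining gap between target and eval values. Setting $\tau = 1$ reduces the rule to a plain copy $\theta' \leftarrow \theta$, which is how the target network is usually initialized from the eval network before training.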
frankfliu commented 2 years ago

@dahongdou111

Yes, you can access a model's parameters; see: https://d2l.djl.ai/chapter_deep-learning-computation/parameters.html

Chinese version: https://d2l-zh.djl.ai/chapter_deep-learning-computation/parameters.html
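As a minimal sketch of the access side, assuming DJL's `ai.djl.nn` API (the class `ParamDump` and its method `describe` are hypothetical names for illustration): `Block.getParameters()` returns a `ParameterList`, which pairs each parameter name with its `Parameter`, and `Parameter.getArray()` returns the `NDArray` holding the current values. For a loaded `Model`, the block would be `model.getBlock()`.

```java
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.util.Pair;

import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper that lists the parameters of an initialized DJL block. */
public final class ParamDump {

    private ParamDump() {}

    /** Returns one "name shape" line per parameter, including those of child blocks. */
    public static List<String> describe(Block block) {
        List<String> lines = new ArrayList<>();
        // getParameters() returns a ParameterList, i.e. a PairList<String, Parameter>
        for (Pair<String, Parameter> entry : block.getParameters()) {
            Parameter param = entry.getValue();
            // getArray() is the NDArray holding the parameter's current values
            lines.add(entry.getKey() + " " + param.getArray().getShape());
        }
        return lines;
    }
}
```

The block must already be initialized (for example with `Block.initialize(...)` or by a `Trainer`); before that, there is no array to read.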

dahongdou111 commented 2 years ago

@frankfliu Yes, I have seen this section, but it only lets me access a model's parameters: I can't set them more than once, because the Parameter class's setArray() method has a restriction. So how can I set the parameters multiple times?

```java
// Excerpt from ai.djl.nn.Parameter
public void setArray(NDArray array) {
    // Fails if a shape (and therefore an array) has already been set: one initialization only
    if (this.shape != null) {
        throw new IllegalStateException("array has been set! Use either setArray or setShape");
    } else {
        this.array = array;
        this.shape = array.getShape();
        array.setName(this.name);
    }
}

public void setShape(Shape shape) {
    // Likewise fails once an array has been assigned
    if (this.array != null) {
        throw new IllegalStateException("array has been set! Use either setArray or setShape");
    } else {
        this.shape = shape;
    }
}
```
frankfliu commented 2 years ago

@dahongdou111 Why do you need to set it multiple times?

If you want to modify the values, you can update the NDArray directly with NDArray.set(Buffer data), or copy the values from another NDArray with NDArray.copyTo(array).
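A minimal sketch of the `copyTo` route, assuming DJL's `Block`/`ParameterList` API: it copies every eval-network parameter into the target network, which corresponds to the hard copy (τ = 1) used to initialize a DDPG target network. `HardUpdate` and `copyParameters` are hypothetical names, and the sketch assumes both networks were built by the same code, so their parameter lists line up by index.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterList;

/** Hypothetical helper: copies every eval-network parameter into the target network. */
public final class HardUpdate {

    private HardUpdate() {}

    public static void copyParameters(Block eval, Block target) {
        ParameterList evalParams = eval.getParameters();
        ParameterList targetParams = target.getParameters();
        // Assumption: both blocks were built by the same code, so parameters line up by index
        if (evalParams.size() != targetParams.size()) {
            throw new IllegalArgumentException("eval and target must have the same structure");
        }
        for (int i = 0; i < evalParams.size(); i++) {
            NDArray src = evalParams.valueAt(i).getArray();
            NDArray dst = targetParams.valueAt(i).getArray();
            if (!src.getShape().equals(dst.getShape())) {
                throw new IllegalArgumentException("shape mismatch at " + evalParams.keyAt(i));
            }
            // Overwrites dst's values in place; Parameter.setArray is never called,
            // so its "array has been set!" check does not apply
            src.copyTo(dst);
        }
    }
}
```

The point of the design is that the target parameters keep their original `NDArray` objects; only the values inside them change, which sidesteps the one-time `setArray` restriction.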

dahongdou111 commented 2 years ago

@frankfliu Because DDPG has two networks with the same structure. The target network has to be updated by hand, so I need to set the target network's parameters multiple times.

zachgk commented 2 years ago

The parameter holds a mutable array, so you can use parameter.getArray().set(...) to apply your updates.
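A minimal sketch of the `getArray().set(...)` approach applied to the DDPG soft update, assuming float32 parameters and the DJL calls `NDArray.toFloatArray()` and `NDArray.set(Buffer)`. `SoftUpdateWithSet` and `softUpdate` are hypothetical names, and, as above, both networks are assumed to have parameter lists that line up by index. It would typically be called once after each training step.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterList;

import java.nio.FloatBuffer;

/** Hypothetical helper: target = (1 - tau) * target + tau * eval, via NDArray.set(Buffer). */
public final class SoftUpdateWithSet {

    private SoftUpdateWithSet() {}

    public static void softUpdate(Block eval, Block target, float tau) {
        ParameterList evalParams = eval.getParameters();
        ParameterList targetParams = target.getParameters();
        if (evalParams.size() != targetParams.size()) {
            throw new IllegalArgumentException("eval and target must have the same structure");
        }
        for (int i = 0; i < targetParams.size(); i++) {
            NDArray e = evalParams.valueAt(i).getArray();
            NDArray t = targetParams.valueAt(i).getArray();
            // Assumes float32 parameters; toFloatArray() copies the values to the JVM heap
            float[] ev = e.toFloatArray();
            float[] tv = t.toFloatArray();
            for (int j = 0; j < tv.length; j++) {
                tv[j] = (1 - tau) * tv[j] + tau * ev[j];
            }
            // Overwrite the values of the existing (mutable) parameter array in place
            t.set(FloatBuffer.wrap(tv));
        }
    }
}
```

This round-trips every value through the JVM heap, which is simple but costs two copies per parameter per update; for large networks, doing the arithmetic with NDArray operations avoids that.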

dahongdou111 commented 2 years ago

Yeah, I have used NDArray.copyTo(array) to apply my updates.
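A minimal sketch of the `copyTo` version of the soft update, assuming DJL's `NDArray.mul(Number)`, `NDArray.add(NDArray)`, and `NDArray.copyTo(NDArray)`: the blend is computed by the engine and then copied into the target parameter's existing array. `SoftUpdateWithCopyTo` and `softUpdate` are hypothetical names; parameters are again assumed to line up by index.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterList;

/** Hypothetical helper: target = (1 - tau) * target + tau * eval, via NDArray ops and copyTo. */
public final class SoftUpdateWithCopyTo {

    private SoftUpdateWithCopyTo() {}

    public static void softUpdate(Block eval, Block target, float tau) {
        ParameterList evalParams = eval.getParameters();
        ParameterList targetParams = target.getParameters();
        if (evalParams.size() != targetParams.size()) {
            throw new IllegalArgumentException("eval and target must have the same structure");
        }
        for (int i = 0; i < targetParams.size(); i++) {
            NDArray e = evalParams.valueAt(i).getArray();
            NDArray t = targetParams.valueAt(i).getArray();
            // Compute the blend on the engine side and close the temporaries right away
            try (NDArray scaledTarget = t.mul(1 - tau);
                    NDArray scaledEval = e.mul(tau);
                    NDArray blended = scaledTarget.add(scaledEval)) {
                blended.copyTo(t); // overwrite the target parameter's values in place
            }
        }
    }
}
```

Closing the three temporaries in a try-with-resources block keeps them from accumulating in the parameters' `NDManager` when the update runs after every training step.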