brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

How to index and slice in the loss function during training #607

Closed hongruj closed 8 months ago

hongruj commented 8 months ago

Hi, BrainPy team.

In loss(), I want to slice the net's output according to preset indices stored in gt, but I ran into some errors. Here is the definition of Trainer():

import time
import jax
import numpy as np
from functools import partial
import brainpy as bp
import brainpy.math as bm

class Trainer():
    def __init__(self, net, opt, data):
        self.net = net
        self.opt = opt
        self.opt.register_train_vars(net.train_vars().unique())
        self.grad = bm.grad(self.loss, grad_vars=net.train_vars().unique(), return_value=True)
        self.data = data

    def loss(self, X, Y, gt):
        # X: input of size (T, batch_size, num_in)
        # Y: target of size (batch_size)
        # gt: an array containing index indicating that from which to slice net's output, of size (batch_size) 

        # reset states; infer the batch size from the targets
        batch_size = Y.shape[0]
        self.net.reset(batch_size)

        # net's output of size (T, batch_size)
        out = bm.for_loop(self.net.step_run, (np.arange(X.shape[0]), X))       

        l = 0

        # slice each sample
        for i in range(batch_size):   

            # get the index
            idx = gt[i]

            # slice the output and compare it with the target
            l3 = ((out[idx:,i]-Y[i])**2).sum()

            l = l + l3
        return l

    ##  gt is set to static, is it wrong O_O???
    @bm.cls_jit(static_argnums=(2))
    def train(self, X, Y, gt):
        grads, l = self.grad(X, Y, gt)
        self.opt.update(grads)
        return l

    def f_train(self, num_epoch):
        train_losses = []
        for i in range(num_epoch):
            t0 = time.time()            
            losses = []
            for X, Y, gt in self.data():            
                l = self.train(X, Y, gt)
                losses.append(l)
            print(f'Train {i} epoch, loss = {np.mean(losses):.4f}, used time {time.time() - t0:.4f} s')
            train_losses.extend(losses)
        return np.asarray(train_losses)

trainer = Trainer(net=net, opt=bp.optim.Adam(lr=2e-3), data=Data(X_data, target, gt))
losses_ = trainer.f_train(100)

Here is the error:

Cell In[83], line 39, in Trainer.f_train(self, num_epoch)
     37 losses = []
     38 for X, Y, gt in self.data():            
---> 39     l = self.train(X, Y, gt)
     40     losses.append(l)
     41 print(f'Train {i} epoch, loss = {np.mean(losses):.4f}, used time {time.time() - t0:.4f} s')

File ~/anaconda3/envs/brainpy/lib/python3.10/site-packages/brainpy/_src/math/object_transform/jit.py:487, in _make_jit_fun.<locals>.call_fun(self, *args, **kwargs)
    485     args_, kwargs_, fun3 = args, kwargs, fun2
    486   with VariableStack() as stack:
--> 487     _ = jax.eval_shape(fun3, *args_, **kwargs_)
    488   del args_, kwargs_
    489 _transform = jax.jit(
    490   _make_transform(fun2, stack),
    491   static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
   (...)
    497   **jit_kwargs
    498 )

    [... skipping hidden 8 frame]

File ~/anaconda3/envs/brainpy/lib/python3.10/site-packages/brainpy/_src/math/object_transform/tools.py:48, in _partial_fun.<locals>.new_fun(*dynargs, **dynkwargs)
     46 i = 0
     47 for arg in static_args:
---> 48   if arg == empty:
     49     args.append(dynargs[i])
     50     i += 1

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Could you please give me some advice? Thank you so much.

chaoming0625 commented 8 months ago

Could you please change bm.cls_jit(static_argnums=(2)) to bm.cls_jit(static_argnames=('gt',)), and then call the function through l = self.train(X, Y, gt=gt)? I don't know whether it will solve the issue.

chaoming0625 commented 8 months ago

Maybe this argument should not be set as static. Removing the static setting may solve the issue?

hongruj commented 8 months ago

Thanks for your reply.

I have tried both solutions and found that the error reports are the same:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. 
Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=2/0)>, None, None). 
To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

The line l3 = ((out[idx:,i]-Y[i])**2).sum() is the problem: idx = gt[i] is dynamic. That is why I set gt as static in the first try (@bm.cls_jit(static_argnums=(2))), but it did not work.

chaoming0625 commented 8 months ago

Here we cannot use the static argument, because your gt is always changing, which means the function would be recompiled frequently. One way to solve the issue is to use jax.lax.dynamic_slice. You can check its usage in the JAX docs.
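For reference, a minimal sketch of how jax.lax.dynamic_slice works (plain JAX, nothing BrainPy-specific): the start indices may be traced values, but the slice sizes must be static Python integers.

import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)

@jax.jit
def take_window(x, start_row):
    # the start indices may be dynamic (traced), but slice_sizes must be static
    return jax.lax.dynamic_slice(x, start_indices=(start_row, 0), slice_sizes=(2, 2))

print(take_window(x, 1))  # a fixed 2x2 window starting at row 1, column 0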

chaoming0625 commented 8 months ago

Another way is to use a mask. XLA can only compile models with fixed shapes/sizes, so you can keep all batches at the same size but apply different masks (0/1 patterns). That is to say, in this loop:

        # slice each sample
        for i in range(batch_size):   

            # get the index
            idx = gt[i]

            # slice the output and compare it with the target
            l3 = ((out[idx:,i]-Y[i])**2).sum()

the for loop should not rely on indexing into gt and then slicing; instead, each batch in out should be multiplied by a different mask, as in the sketch below.
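A minimal sketch of what such a masked loss() could look like (my illustration, not code posted in the thread; it assumes out has shape (T, batch_size) and Y has shape (batch_size,), as above). The mask reproduces the effect of out[gt[i]:, i] for every sample at once, so the Python loop disappears:

    def loss(self, X, Y, gt):
        # reset states; infer the batch size from the targets
        batch_size = Y.shape[0]
        self.net.reset(batch_size)

        # net's output of size (T, batch_size)
        out = bm.for_loop(self.net.step_run, (np.arange(X.shape[0]), X))
        T = out.shape[0]

        # mask[t, i] == 1 exactly where t >= gt[i], i.e. the region out[gt[i]:, i]
        mask = (bm.arange(T)[:, None] >= gt[None, :]).astype(out.dtype)

        # masked squared error; Y broadcasts over the time axis
        return (((out - Y[None, :]) ** 2) * mask).sum()

With this, no shape in the function depends on the values of gt, so the static_argnums setting can be dropped and gt passed as an ordinary (traced) argument.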

hongruj commented 8 months ago

Thanks for your advice.

I tried jax.lax.dynamic_slice, but I am still struggling with the index. Here is my code:

# slice each sample
for i in range(batch_size):
    idx = jax.lax.dynamic_index_in_dim(gt, i)
    l1 = jax.lax.dynamic_slice(out, (0, i), (idx, 1))  # got error

idx is dynamic, and it reports:

 TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>, 1).

The net's output is a time series, and the indices in gt are needed to slice out the period we want (each sample has a different period). Is there another way to get the period?

chaoming0625 commented 8 months ago

jax.lax.dynamic_slice cannot solve your problem when you use JIT. As I said, XLA only compiles models with static shapes: the slice sizes must be static, and a size taken from gt makes the shape dynamic.
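For what it's worth, the constraint is easy to reproduce in plain JAX, independent of BrainPy:

import jax
import jax.numpy as jnp

@jax.jit
def f(x, idx):
    # the result's shape would depend on the value of idx, which JIT forbids
    return x[idx:]

# f(jnp.arange(5), 2)  # IndexError: slice indices must have static start/stop/step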

hongruj commented 8 months ago

Thanks for your time and explanation. It took me a few moments to understand it. Masking is a good idea.