Closed hongruj closed 8 months ago
Could you please change bm.cls_jit(static_argnums=(2)) to bm.cls_jit(static_argnames=('gt')), then call the function through l = self.train(X, Y, gt=gt)? I don't know whether it will solve the issue.
Maybe this argument should not be static at all. Would removing the static setting solve the issue?
Thanks for your reply.
I have tried both solutions and found that the error reports are the same:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax.
Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=2/0)>, None, None).
To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
The line l3 = ((out[idx:,i]-Y[i])**2).sum() is the problem, because idx = gt[i] is dynamic. So I made gt static in my first try (@bm.cls_jit(static_argnums=(2))), but it did not work.
Here we cannot use a static argument, because your gt is always changing, which means the function would be recompiled frequently. One way to solve the issue is to use jax.lax.dynamic_slice. You can check its usage in the JAX docs.
Another way is to use a mask. XLA can only compile models with fixed shapes/sizes. You can keep all batches the same size but apply different masks (0/1 patterns). That is, in the following code:
# slice each sample
for i in range(batch_size):
    # get the index
    idx = gt[i]
    # slice the output and compare it with the target
    l3 = ((out[idx:,i]-Y[i])**2).sum()
The for loop here should not index gt and then slice; instead, each batch entry in out is multiplied by a different mask.
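The masking idea can be sketched as follows. This is only a minimal illustration in plain JAX, not the BrainPy trainer from this issue: the function name masked_loss is hypothetical, and it assumes out has shape (num_steps, batch_size), Y holds one scalar target per sample, and gt holds one integer start index per sample.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_loss(out, Y, gt):
    # Assumed shapes (not given in the thread):
    #   out: (num_steps, batch_size), Y: (batch_size,), gt: (batch_size,)
    num_steps = out.shape[0]                # shapes are static under jit
    t = jnp.arange(num_steps)[:, None]      # (num_steps, 1) time indices
    mask = t >= gt[None, :]                 # True where t >= gt[i]
    sq_err = (out - Y[None, :]) ** 2        # same fixed shape as out
    # Zero out the steps before gt[i] instead of slicing them away.
    return jnp.where(mask, sq_err, 0.0).sum(axis=0)


out = jnp.arange(12.0).reshape(4, 3)
Y = jnp.array([1.0, 2.0, 3.0])
gt = jnp.array([0, 2, 3])
losses = masked_loss(out, Y, gt)            # one loss value per sample
```

Under these assumptions, entry i of losses matches ((out[gt[i]:, i] - Y[i])**2).sum(), but every array keeps a fixed shape, so gt can change between calls without triggering recompilation.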
Thanks for your advice.
I tried jax.lax.dynamic_slice, but I am still struggling with the index. Here is my code:
# slice each sample
for i in range(batch_size):
    idx = jax.lax.dynamic_index_in_dim(gt, i)
    l1 = jax.lax.dynamic_slice(out, (0,i), (idx,1)) #got error
idx is dynamic. It reports:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>, 1).
The net's output is a time series, and the index in gt is needed to slice out the period we want (each sample has a different period). Is there another way to get the period?
jax.lax.dynamic_slice cannot solve your problem when you use JIT. As I have said, XLA only compiles models with static shapes. Slicing with a dynamic size implies a dynamic shape.
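To make the distinction concrete, here is a small sketch in plain JAX (the function name window_from and the array sizes are made up for illustration). jax.lax.dynamic_slice accepts a traced start index, but the slice sizes must be concrete Python integers, which is why passing idx as a size fails.

```python
import jax
import jax.numpy as jnp


@jax.jit
def window_from(out, start):
    # `start` is traced under jit, which is allowed for start indices.
    # The slice sizes (2, 1) are plain Python ints, so the result shape
    # is known at compile time.
    return jax.lax.dynamic_slice(out, (start, 0), (2, 1))


out = jnp.arange(12.0).reshape(4, 3)
window = window_from(out, 1)   # rows 1..2 of column 0, shape (2, 1)
```

A slice whose length depends on gt[i] has no fixed shape, so it cannot be expressed this way; that is the case where masking is needed.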
Thanks for your time and explanation. It took me a few moments to understand it. Masking is a good idea.
Hi, BrainPy team.
In loss(), I want to slice the net's output according to the preset indices stored in 'gt', but I ran into some errors. Here is the definition of Trainer():
Here is the error:
Could you please give me some advice? Thank you so much.