google / trajax

Apache License 2.0
186 stars, 23 forks

CEM #3

Open: mkolodziejczyk-piap opened this issue 2 years ago

mkolodziejczyk-piap commented 2 years ago

Hi, I noticed a couple of problems when running CEM.

  1. On this line, the last index should be 7: https://github.com/google/trajax/blob/67fd5ed4867914f1ce1ab78819551c12206a7773/trajax/optimizers.py#L753 (the same applies to random shooting: https://github.com/google/trajax/blob/67fd5ed4867914f1ce1ab78819551c12206a7773/trajax/optimizers.py#L818)

  2. The default hyperparameters should be a frozendict rather than a plain dict: https://github.com/google/trajax/blob/main/trajax/optimizers.py#L685-L692
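A minimal sketch of both points, assuming the index in point 1 is a jax.jit `static_argnums` entry meant to mark the hyperparameters argument as static. The stub assumes `cem` takes `(cost, dynamics, init_state, init_controls, control_low, control_high, random_key, hyperparams)` in that order; that order has not been checked against optimizers.py, so verify it before relying on the result. `cem_signature_stub`, `positional_index`, and `FrozenDict` are hypothetical names, and `FrozenDict` is only a stand-in for a real frozendict type.

```python
import inspect
from collections.abc import Mapping


# Point 1: find the positional index that static_argnums would need.
def cem_signature_stub(cost, dynamics, init_state, init_controls,
                       control_low, control_high, random_key=None,
                       hyperparams=None):
    """Hypothetical stub; parameter order assumed, not copied from trajax."""


def positional_index(fn, name):
    """Return the 0-based positional index of parameter `name` in `fn`."""
    return list(inspect.signature(fn).parameters).index(name)


# Point 2: an immutable, hashable mapping usable as a static argument.
class FrozenDict(Mapping):
    """Minimal stand-in for a frozendict type (hypothetical, for illustration)."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        # Order-independent hash; like frozendict, it requires hashable values.
        return hash(frozenset(self._data.items()))

    def __repr__(self):
        return f"FrozenDict({self._data!r})"


# Illustrative keys and values only, not trajax's actual CEM defaults.
DEFAULT_HYPERPARAMS = FrozenDict(num_samples=400, max_iter=10)

print(positional_index(cem_signature_stub, "hyperparams"))  # 7
print(hash(DEFAULT_HYPERPARAMS) == hash(FrozenDict(max_iter=10, num_samples=400)))  # True
```

Under the assumed parameter order, hyperparams sits at positional index 7. The frozen mapping has no item assignment and hashes identically regardless of key order, which is the property a static argument needs; a plain dict raises TypeError when hashed.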

Would you like a PR for this?

Regards,

stephentu commented 2 years ago

Good catch! Feel free to send a PR.

kwesiRutledge commented 11 months ago

I suspect I'm running into a similar issue when using the cem and random_shooting functions. I get the following error:

Traceback (most recent call last):
  File "XX/trajax/tests/optimizers_test.py", line 850, in testCEM1
    X_opt, U_opt, obj, = optimizers.random_shooting(
  File "XX/trajax/trajax/optimizers.py", line 1160, in random_shooting
    controls = gaussian_samples(random_key, mean, stdev, control_low,
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 5) of type <class 'dict'> for function gaussian_samples is non-hashable.

when calling random_shooting. I'll investigate and open a PR if I can fix it quickly.
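The error arises because jax.jit keys its compilation cache on the values of static arguments, so each one must be hashable, and a plain dict (presumably the hyperparameters, at index 5) is not. The toy decorator below models only that hashing check and per-key cache in plain Python; it does not trace or compile anything. `jit_like` and `toy_sampler` are hypothetical names, and `toy_sampler` mirrors only the position of the dict argument in the `gaussian_samples` call from the traceback. Passing the same settings in a hashable form, here a sorted tuple of items, gets past the check.

```python
import functools


def jit_like(*static_argnums):
    """Toy model of jit's static-argument cache; it does not trace or compile."""
    def decorator(fn):
        cache = {}  # static-argument values -> number of calls served

        @functools.wraps(fn)
        def wrapper(*args):
            for i in static_argnums:
                try:
                    hash(args[i])
                except TypeError:
                    # Mimics the wording of the error in the traceback above.
                    raise ValueError(
                        f"Static argument (index {i}) of type {type(args[i])} "
                        f"for function {fn.__name__} is non-hashable.") from None
            key = tuple(args[i] for i in static_argnums)
            cache[key] = cache.get(key, 0) + 1
            return fn(*args)

        wrapper.cache = cache
        return wrapper
    return decorator


@jit_like(5)
def toy_sampler(random_key, mean, stdev, low, high, hyperparams):
    """Hypothetical toy; only the position of `hyperparams` (index 5) matters."""
    return mean + stdev * dict(hyperparams)["scale"]


settings = {"scale": 2.0}
try:
    toy_sampler(0, 1.0, 0.5, -1.0, 1.0, settings)
except ValueError as err:
    # Static argument (index 5) of type <class 'dict'> for function toy_sampler is non-hashable.
    print(err)

# A hashable view of the same settings passes the check.
frozen = tuple(sorted(settings.items()))
print(toy_sampler(0, 1.0, 0.5, -1.0, 1.0, frozen))  # 2.0
```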

I've attached the test method I added to OptimizersTest that reproduces this behavior: testCEM1.txt