riiswa / kanrl

Kolmogorov-Arnold Network for Reinforcement Learning, initial experiments

Add initial REINFORCE-kan code #8

Open · db7894 opened 6 months ago

db7894 commented 6 months ago

Hello, not expecting/planning to merge right now, but I began writing up a basic take on KAN for REINFORCE based on the simplest possible version plus one experiment. I added efficient_kan as well for comparison, and a plot with results from one run (on my laptop, lol) with REINFORCE. The results look pretty bad / a bit weird, and I only tested with 8 random seeds instead of 32 for a start. Planning to keep hacking away at this, but thought I'd open a draft PR in case you wanted to discuss extending to more algorithms, since I think some refactoring is probably a good idea for people who might want to add more.
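For anyone skimming the thread, a minimal sketch of the idea (not the exact code in this PR): REINFORCE where the usual MLP policy is swapped for efficient_kan's KAN. The import path and the `KAN([widths])` constructor are assumptions based on the linked efficient_kan repo.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical
from efficient_kan import KAN  # assumed import; see the efficient_kan repo credited below


class KANPolicy(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 32):
        super().__init__()
        # KAN takes a list of layer widths, analogous to an MLP's layer sizes.
        self.net = KAN([obs_dim, hidden, n_actions])

    def forward(self, obs: torch.Tensor) -> Categorical:
        logits = self.net(obs)
        return Categorical(logits=logits)


def reinforce_step(policy, optimizer, obs, actions, returns):
    # REINFORCE update for one batch: log-probs weighted by (discounted) returns.
    dist = policy(obs)
    loss = -(dist.log_prob(actions) * returns).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```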

riiswa commented 6 months ago

Thanks for this PR :D! Indeed, I'd like to see other algorithms added; it would be interesting to refactor the repo to make it more flexible and composable!

It may not be easy to apply KANs in an online setting, which may explain the rather odd results... How much faster is your KAN implementation than the official one? Do you think it's possible to run a hyperparameter search with Optuna or something else (in a reasonable amount of time)?
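For concreteness, a rough sketch of how such an Optuna search could look; `train_reinforce` and the parameter ranges are illustrative placeholders, not functions from the repo.

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Hypothetical search space for the KAN policy; adjust to the real training entry point.
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    hidden = trial.suggest_categorical("hidden", [16, 32, 64])
    grid_size = trial.suggest_int("grid_size", 3, 10)
    # train_reinforce is assumed to return mean episodic return averaged over a few seeds.
    return train_reinforce(lr=lr, hidden=hidden, grid_size=grid_size, n_seeds=4)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params)
```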

@corentinlger is also working on REINFORCE; maybe you can compare your results.

db7894 commented 6 months ago

Awesome. On refactoring: let me throw out another PR, or at least open an issue with ideas, once I've played around more; I have some thoughts off the top of my head based on usage so far.

On the special results, my MLP results also look really bad (to me at least), so maybe I've messed something else up 😅

For EfficientKAN I credited this repo, which I'm currently using as-is; I haven't had a chance to explore whether it's possible to squeeze out more juice and haven't yet checked exact runtimes (but that is on my todo list!). I'll see if I can do a hyperparameter search with just the efficient version. I'm just working on my MacBook right now, but it was able to at least do the 8-seed multirun experiment, so maybe it'll hold up!

db7894 commented 6 months ago

I wouldn't call this anything definitive since I haven't done hyperparameter sweeps or anything, but using standard values and trying MLP, KAN, and the efficient version (with 16 seeds this time), I'm seeing this rather interesting set of results.

[Figure: carpole_mlp_kan_efficientkan — CartPole results comparing MLP, KAN, and EfficientKAN]

corentinlger commented 6 months ago

Hi, thanks for the PR! Do you know why the episode length exceeds 500 in the results?

I also implemented REINFORCE with an MLP and a KAN, and got these results on 5 seeds (300_000 steps of training on CartPole-v1):

[Figure: reinforce_results — REINFORCE on CartPole-v1, MLP vs. KAN, 5 seeds]

I agree some files could be refactored to facilitate integration of new algorithms in the repo. We can discuss both points if you want!

db7894 commented 6 months ago

Gotcha, are you using any bells and whistles or just standard REINFORCE? My second plot (above comment) was with reward-to-go (rtg); I'm still not sure why the efficient_kan version didn't run for the full 500 episodes (and KAN didn't learn anything!).
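For reference, by reward-to-go I mean weighting each timestep by the discounted sum of rewards from that step onward rather than by the full episode return. A small sketch (gamma and the reward list are placeholders):

```python
def rewards_to_go(rewards, gamma=0.99):
    # Discounted suffix sums: rtg[t] = r[t] + gamma * r[t+1] + gamma^2 * r[t+2] + ...
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg
```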

As for why the episode lengths are so long, I'm not really sure; I'm going to step back and just look at the MLP version to see what's up.

On refactoring: I posted some ideas in my last comment, and PR #11 is a first step that just switches to a main experiment driver that can dispatch to other algorithm scripts (see the sketch below). Let me know if you have any thoughts!
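One possible shape for that driver, purely illustrative (PR #11 may differ); `run_reinforce` is a hypothetical entry point exposed by the algorithm script.

```python
import argparse


def main():
    parser = argparse.ArgumentParser(description="kanrl experiment driver (sketch)")
    parser.add_argument("--algo", choices=["reinforce"], default="reinforce")
    parser.add_argument("--network", choices=["mlp", "kan", "efficient_kan"], default="mlp")
    parser.add_argument("--env-id", default="CartPole-v1")
    parser.add_argument("--seeds", type=int, default=8)
    args = parser.parse_args()

    if args.algo == "reinforce":
        # Hypothetical per-algorithm module; the driver only parses args and dispatches.
        from reinforce import run_reinforce
        run_reinforce(env_id=args.env_id, network=args.network, n_seeds=args.seeds)


if __name__ == "__main__":
    main()
```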

[PS: you might have already seen this, but here's an interesting notebook on KAN/MLP]

corentinlger commented 6 months ago

Oh, actually I was talking about the episode length of 500 on the y-axis ahah (which can be misleading because you also train for 500 episodes). But I saw you exceed 500 steps/ep sometimes because you don't use the truncated flag of the environment in your code.
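To illustrate the point: with Gymnasium's API an episode ends when either `terminated` or `truncated` is set, and CartPole-v1 truncates at 500 steps, so ignoring the truncated flag lets rollouts run past 500. A minimal sketch:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, _ = env.reset(seed=0)
done = False
steps = 0
while not done:
    action = env.action_space.sample()  # placeholder policy
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # the truncated check is the part that was missing
    steps += 1
print(steps)  # never exceeds 500 on CartPole-v1
```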

And yes, actually I implemented a slightly different algorithm from REINFORCE. It's still a simple policy-gradient algorithm, but I update the network every n_steps (and then reset the environment) instead of updating it at the end of each episode.
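A rough sketch of that variant, under my own assumptions (not corentinlger's exact code): collect a fixed number of environment steps, do one policy-gradient update on that batch, and reset the running return at episode boundaries.

```python
import torch


def collect_batch(env, policy, n_steps=1024, gamma=0.99):
    obs, _ = env.reset()
    logp_l, rew_l, done_l = [], [], []
    for _ in range(n_steps):
        dist = policy(torch.as_tensor(obs, dtype=torch.float32))
        action = dist.sample()
        logp_l.append(dist.log_prob(action))
        obs, reward, terminated, truncated, _ = env.step(action.item())
        rew_l.append(reward)
        done_l.append(terminated or truncated)
        if done_l[-1]:
            obs, _ = env.reset()
    # Reward-to-go within each episode of the batch; reset the running sum at episode ends.
    returns, running = [0.0] * n_steps, 0.0
    for t in reversed(range(n_steps)):
        running = 0.0 if done_l[t] else running
        running = rew_l[t] + gamma * running
        returns[t] = running
    return torch.stack(logp_l), torch.as_tensor(returns, dtype=torch.float32)


# Usage: one gradient update per batch of n_steps transitions.
# logp, returns = collect_batch(env, policy)
# loss = -(logp * returns).mean(); loss.backward(); optimizer.step()
```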

db7894 commented 6 months ago

Yeah, I understood that you meant the y axis haha—thanks, it slipped by me that I wasn't using truncation! And gotcha, I'll play around and see if doing that gives me different results.

yuzej commented 6 months ago

> Gotcha, are you using any bells and whistles or just standard REINFORCE? My second plot (above comment) was with reward-to-go (rtg); I'm still not sure why the efficient_kan version didn't run for the full 500 episodes (and KAN didn't learn anything!).
>
> As for why the episode lengths are so long, I'm not really sure; I'm going to step back and just look at the MLP version to see what's up.
>
> On refactoring: I posted some ideas in my last comment, and PR #11 is a first step that just switches to a main experiment driver that can dispatch to other algorithm scripts. Let me know if you have any thoughts!
>
> [PS: you might have already seen this, but here's an interesting notebook on KAN/MLP]

Seems like KAN plays better than MLP here.