KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
14.76k stars, 1.35k forks

Add plt.savefig(f'{folder}/'+file_name) to plot() (Kan.py) to save the figure #114

Closed: Simba1999 closed this issue 2 months ago

Simba1999 commented 5 months ago

In Kan.py, at line 758 (just above train()), add:

plt.savefig(f'{folder}/'+file_name)

Then, in the definition of plot(), add file_name="file" after title:

def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, file_name="file"):

After that, you can save the figure to ./figures with model.plot(file_name="org").

I also commented out plt.gca().spines[:].set_color(color), because it raised TypeError: unhashable type: 'slice'.

JinglaiZheng commented 5 months ago

Could you make that clearer? I'm struggling to draw a KAN diagram, thank you very much

JinglaiZheng commented 5 months ago

Sorry to bother you. I know how to do it now! Just add file_name="file" as a parameter in def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None). Then add plt.savefig(f'{folder}/'+file_name) at line 758, indented by eight spaces.
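For readers who want to try the pattern outside the library first, here is a minimal standalone sketch of the same idea: a plotting function that takes folder and file_name parameters and calls savefig at the end. The function plot_and_save, the dummy curve, the ".png" suffix, and the dpi value are assumptions made for illustration only; this is not pykan's actual plot() body.

```python
import os

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_and_save(folder="./figures", title=None, file_name="file"):
    """Hypothetical stand-in for plot(): draw a figure, then save it."""
    # Create the output folder if it does not exist yet.
    os.makedirs(folder, exist_ok=True)

    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])  # placeholder content, not a KAN diagram
    if title is not None:
        ax.set_title(title)

    # Same idea as plt.savefig(f'{folder}/'+file_name) in the thread,
    # with an explicit extension so the output path is unambiguous.
    out_path = os.path.join(folder, file_name + ".png")
    fig.savefig(out_path, bbox_inches="tight", dpi=200)
    plt.close(fig)  # release the figure's memory
    return out_path
```

Calling plot_and_save(file_name="org") would write ./figures/org.png, mirroring model.plot(file_name="org") in the thread.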

Simba1999 commented 5 months ago

yep🤝

KindXiaoming commented 5 months ago

Hi, another user ran into the same problem. It seems to be caused by your matplotlib version. I suggest using matplotlib==3.6.2.
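If pinning the matplotlib version is not an option, one possible workaround is to avoid slice indexing on spines altogether. As far as I know, ax.spines is a plain ordered dictionary in older matplotlib releases, which is why spines[:] raises TypeError: unhashable type: 'slice' there, while newer releases accept the slice. Iterating over the values works with both. The helper name set_spine_color below is hypothetical and not part of pykan or matplotlib.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def set_spine_color(ax, color):
    """Color every spine of ax without relying on spines[:] slicing."""
    # ax.spines maps names ('left', 'right', 'top', 'bottom') to Spine
    # objects, so .values() is available whether or not slicing is.
    for spine in ax.spines.values():
        spine.set_color(color)


fig, ax = plt.subplots()
set_spine_color(ax, "red")  # stands in for plt.gca().spines[:].set_color(color)
plt.close(fig)
```

This only sidesteps the one failing line; other parts of the plotting code may still depend on the recommended matplotlib version.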