harujoh / KelpNet

Pure C# machine learning framework
Apache License 2.0
242 stars 28 forks source link

Trainer クラスで使用される Array[] について #14

Closed harujoh closed 7 years ago

harujoh commented 7 years ago

以下はnyatlaさんのこちらの指摘を抜粋したものです https://github.com/harujoh/KelpNet/issues/13

Trainerクラスの以下の関数でちょっと気がかりなことがあります。

       public static double Train(FunctionStack functionStack, Array input, Array teach, ILossFunction lossFunction, bool isUpdate = true)

public static double Train(FunctionStack functionStack, Array[] input, Array[] teach, ILossFunction lossFunction, bool isUpdate = true)

についてですが、引数として多次元配列を引き渡した場合に、呼び出される関数に混乱があるように見えます。 具体的には、input=double[][]とteach=double[]を引き渡した場合に、後者が選択されずに前者が選択されることがありました。 この例では当然そうなることは予測できるのですが、問題に感じるのはArrayがArray[]とArrayを区別できない点です。 やや面倒ですが、train関数は数値型ごとに次元数を固定した関数をオーバロードしたほうがよいかもしれません。

harujoh commented 7 years ago

この件について問題の認識は有りました。

これを修正するのであれば Array[] を全て撲滅すべきだと考えているのですが、その対応方法が泥臭く全ての型を記述する以外の効果的な手法が思いつきませんでした。

また、泥臭く全て記述する場合においても、労力に見合った見返りが望めず、また可読性、メンテナンス性の低下が懸念されるため、問題として放置していたというのが正直なところです。

一旦ですが、対処法としてオーバーロードではなく関数名を BatchTrain に変更し、明示的に Array[] 側を呼び出させる様にしたいと思います。

nyatla commented 7 years ago

返信ありがとうございます。 最終的には、Arrayではなく、NdArrayを基本型に据えていけるとよさそうに思えます。 数値型についてはNdArrayに変換する過程ですべてdoubleに変換してしまうことで、問題を回避できると思います。 (JavaにはArrayのようなものがないので、今のところすべてNdArrayに整形して入力するようにしました。)

これに関連してですが、MnistDataのラッパーオブジェクトについて新規にissueを上げますので、このスレッドについて不要ならCloseしてください。

harujoh commented 7 years ago

更新のアナウンスをまだ流していないのですが、関数名をオーバーロードではなく BatchTrain に変更しております。 一度、ご確認をお願い致します。