Qulacs-Osaka / scikit-qulacs

scikit-qulacs is a library for quantum neural network. This library is based on qulacs and named after scikit-learn.
https://qulacs-osaka.github.io/scikit-qulacs/index.html
MIT License
21 stars 6 forks source link

並列分散学習できるようにしたい #198

Open forest1040 opened 2 years ago

forest1040 commented 2 years ago

学習に時間がかかる場合があるため、藤井研のクラスタを使用して、並列分散学習できるようにしたい。 学習した結果(パラメータ)を各ノードから集めて、平均化して各ノードに更新する必要がある。

ikanago commented 2 years ago

numpy を置き換えて GPU や vectorization で高速化する jax というライブラリがあります. クラスタに対応してるかは分からないのですが,CPU の環境でも置き換えるだけで速くなるかも含めて今度調べてみます.

wmizukami commented 2 years ago

ご存知かと思いますが、jax + MPI の実装として、MPI4JAXがあります。そちらの性能など興味があるので、もし使われましたら情報ご共有いただければ。

ikanago commented 2 years ago

ありがとうございます. scikit-qulacs の JAX 版を別で作ると大変なので,高速化できることが確認できたら numpy を完全に置き換えてもよさそうです.

forest1040 commented 2 years ago

並列分散学習なのですが、データ並列性とモデル並列性があると思っています。 https://tech.preferred.jp/ja/blog/model-parallelism-in-dnn/

データ並列性は、同じデータセット(別でもいいでしょうけれど。)を複数ノードでノード単体で学習させて 学習結果を集めて平均取って各ノードに反映して、また分散学習するという流れだと思います。

モデル並列性は、ノード全体のメモリ共有型で、1つの巨大なモデルを1つのデータセットで、学習する形だと思います。 MPI4JAXは、こちら側なのかなと思いました。(もちろんデータ並列性もできるでしょうけど。。)

MPI4JAXを少しだけ見たのですが、Jax配列がzero copyだと書いてあったので、そこがおもしろそうです。 多分、mpiのrecvとsendバッファにコピーせずに、通信ができることを言っているんだと思います。 ちょんとコードは読んでいませんが。。 あと、XLAに対応しているので、GPUの他にもいろいろなデバイスを使って、今後計算できそうです。(TPUとか)

ikanago commented 2 years ago

ただ,scikit-qulacs では入出力のデータ表現として numpy を使っているだけで,行列演算は qulacs-osaka が Eigen で行っているので,jax を使うメリットが薄いようにも感じました.

forest1040 commented 2 years ago

確かに!まぁ、ちょっと並列分散しようと思うと、もう少し構成を考えないといけないですね。。