o-jill / tigerdenversi

machine learning for ruversi.
0 stars 0 forks source link

重みファイルの読み込み #6

Closed o-jill closed 2 weeks ago

o-jill commented 1 month ago

safetensorsファイルではなくcsv形式のデータを読み込んで初期値としたい。

VarStore

pub struct VarStore {
    pub variables_: Arc<Mutex<Variables>>,
  // <-- 外から触れるはず
}
pub struct Variables {
    pub named_variables: HashMap<String, Tensor>,  // <-- 外から触れるはず
    pub trainable_variables: Vec<Var>,
}
o-jill commented 1 month ago

一旦こんな感じで行けそう。学習に適用できるかどうかは不明。 https://gist.github.com/o-jill/cae97f16465206e053a65927e384cd4d

o-jill commented 1 month ago

ファイルの読み込みのためにweight.rsを拝借して、計算は削除、読み書きと意図したレイヤの重みにアクセスできるとありがたい。なんならあるレイヤのTensorを返してしまってもいいかもしれない。 それぐらいならフルスクラッチでもいいかもしれない。 NNの構成とかもこっちでやってもいいのかもしれない。weight.rsというよりはevalnn.rsみたいな名前か?

o-jill commented 1 month ago

読みこんで設定出来たけど学習しても更新されてない模様。。。

o-jill commented 1 month ago

sfentensorsだとpanic!

o-jill commented 1 month ago

これか?insertじゃなくてコピーをしないと駄目の模様。

https://github.com/LaurentMazare/tch-rs/blob/fbb70396b155c7095927569e39f807f59a725dcb/examples/llama/main.rs#L376

    {
        let file = std::fs::File::open("llama.safetensors")?;
        let content = unsafe { memmap2::MmapOptions::new().map(&file)? };
        let safetensors = safetensors::SafeTensors::deserialize(&content)?;

        let mut variables = vs.variables_.lock().unwrap();
        for (name, var) in variables.named_variables.iter_mut() {
            let view = safetensors.tensor(name)?;
            let size: Vec<i64> = view.shape().iter().map(|&x| x as i64).collect();
            let kind: Kind = view.dtype().try_into()?;
            // Using from_blob here instead of from_data_size avoids some unnecessary copy.
            let src_tensor =
                unsafe { Tensor::from_blob(view.data().as_ptr(), &size, &[], kind, Device::Cpu) };
            var.f_copy_(&src_tensor)?;
        }
    }
o-jill commented 3 weeks ago

txtから読むときに、 layer1.biasとlayer2.weightのサイズが正しくない。 layer1.biasは[HIDDEN]で、layer2.weightは[1, HIDDEN]じゃないといけないとこが、 layer1.biasは[1, HIDDEN]で、layer2.weightは[HIDDEN]になっている。

コレを直すだけでinsert形式での値の更新で行けそう。 →差分が出てたのはヘッダの後ろに改行が抜けてたからだった。(別でチケを起こそう) あとは、epoch=0に指定して読み込んだ値のままtxtが吐き出されればOK。

o-jill commented 2 weeks ago

set_dataが良さそう。