srush / llama2.rs

A fast llama2 decoder in pure Rust.
MIT License

Update export.py #31

Closed by echosprint 1 year ago

echosprint commented 1 year ago

torch.zeros(..., dtype=int) creates a tensor with dtype torch.int64. We want an int32 tensor, so the dtype needs to be specified explicitly as torch.int (an alias for torch.int32).
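A minimal sketch of the difference, assuming PyTorch is installed. The shape `4` and the variable names are placeholders for illustration only and are not taken from export.py:

```python
import torch

# Hypothetical shape; the real call in export.py uses its own arguments.
# Python's built-in int is interpreted by PyTorch as a 64-bit integer dtype.
a = torch.zeros(4, dtype=int)

# torch.int is an alias for torch.int32, so this yields a 32-bit tensor.
b = torch.zeros(4, dtype=torch.int)

print(a.dtype)  # torch.int64
print(b.dtype)  # torch.int32

# Bytes per element: 8 for int64, 4 for int32. This matters when the
# tensor is written to a binary file that the reader parses as int32.
print(a.element_size(), b.element_size())
```

Writing `dtype=torch.int32` instead of `dtype=torch.int` is equivalent and arguably makes the intended width more obvious.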