srush / llama2.rs

A fast llama2 decoder in pure Rust.
MIT License
995 stars 54 forks

Quick review #7

Open CodesInChaos opened 11 months ago

CodesInChaos commented 11 months ago
  1. The code and comment don't match (bytes vs codepoint):

    // first encode every individual byte in the input string
    for c in text.chars() {

  2. Is this really performance critical? Would the checked version have unacceptable performance?

    unsafe { from_utf8_unchecked(&buf).to_owned() }

    Also you can use String::from_utf8(_unchecked) to convert a Vec<u8> to String without reallocating.

    If the tokens aren't valid UTF-8 (e.g. because a code-point might be split into multiple tokens), then use Vec<u8>/[u8] instead of String/str.

  3. Why not use a map for this (HashMap or BTreeMap) instead of a linear search? That would get you down to O(1) or O(log(n)) instead of O(n).

    vocab.into_iter().position(|x| x == str)

  4. Why not (x + y - 1) / y?

    x / y + if x % y == 0 { 0 } else { 1 }

  5. Explain why this will always succeed in a comment or the expect message:

    let id = str_lookup(&str_buffer, self.vocab.as_slice()).expect("not good");

    Due to the codepoint/bytes confusion (point 1), this might actually fail.

  6. weights should not be a Box, since you don't own it, and since it doesn't originate from a box. You should use an &TWeights instead.

    Also add an assertion that the memory is aligned correctly.

  7. .unwrap(): At minimum, I'd replace most of these with expect to improve the error messages, or switch the fallible ones to return a Result.
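Points 1, 3, and 4 can be illustrated with a small sketch. This is not code from llama2.rs: the helper names `build_lookup` and `ceil_div` and the toy vocabulary are hypothetical, and the sketch assumes the vocabulary is a slice of `String`s.

```rust
use std::collections::HashMap;

/// Build a token -> id map once, so each lookup is O(1) on average
/// instead of a linear scan over the vocabulary (point 3).
/// Like `position`, the first id wins if a token appears twice.
fn build_lookup(vocab: &[String]) -> HashMap<&str, usize> {
    let mut map = HashMap::with_capacity(vocab.len());
    for (id, tok) in vocab.iter().enumerate() {
        map.entry(tok.as_str()).or_insert(id);
    }
    map
}

/// Ceiling division for unsigned integers (point 4).
/// Assumes y > 0 and that x + y - 1 does not overflow.
fn ceil_div(x: usize, y: usize) -> usize {
    (x + y - 1) / y
}

fn main() {
    // Point 1: `chars()` yields Unicode code points, `bytes()` yields UTF-8 bytes.
    // "n" is one byte; U+00E9 (e with acute accent) is two bytes in UTF-8.
    let text = "n\u{e9}";
    println!("chars: {}, bytes: {}", text.chars().count(), text.bytes().count());

    let vocab = vec!["a".to_string(), "b".to_string(), "ab".to_string()];
    let lookup = build_lookup(&vocab);
    // A missing token yields None, which the caller has to handle (point 5).
    println!("{:?} {:?}", lookup.get("ab"), lookup.get("zz"));

    println!("{} {}", ceil_div(7, 2), ceil_div(8, 2));
}
```

The map is built once and reused for every lookup, which is where the gain over repeated linear searches comes from.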

srush commented 11 months ago

Thanks! This is all really helpful. String stuff tripped me up a bit.

If you have a minute, can you explain 6 to me? How do I check that it is aligned? And how do I get a &TWeights from a pointer without a Box?

CodesInChaos commented 11 months ago

Something like unsafe { &*(mmap.as_ptr() as *const TWeights) } should work. I think this will produce a reference with an unconstrained lifetime, which acts much like an &'static T.

To check the alignment, you could use:

assert_eq!(mmap.as_ptr().align_offset(std::mem::align_of::<TWeights>()), 0);
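Putting the cast and the alignment check together, a minimal sketch might look like the following. It makes several assumptions: the `TWeights` struct here is a hypothetical stand-in containing only `f32` fields, `view_weights` is an illustrative helper rather than anything in llama2.rs, and a `Vec<f32>` stands in for the memory map so the example has no external dependencies. Note that `align_offset` takes the alignment in bytes as a `usize`.

```rust
use std::mem::{align_of, size_of};

/// Hypothetical stand-in for the real weights struct.
/// `repr(C)` gives the fields a defined order and layout.
#[repr(C)]
struct TWeights {
    embedding: [f32; 4],
    bias: [f32; 2],
}

/// Reinterpret a byte slice as `&TWeights`, checking size and alignment first.
/// The returned reference borrows from `bytes`, so it cannot outlive the buffer.
fn view_weights(bytes: &[u8]) -> &TWeights {
    assert!(bytes.len() >= size_of::<TWeights>(), "buffer too small for TWeights");
    assert_eq!(
        bytes.as_ptr().align_offset(align_of::<TWeights>()),
        0,
        "buffer is not aligned for TWeights"
    );
    // SAFETY: size and alignment were checked above, and TWeights contains
    // only f32 fields, so every bit pattern is a valid value.
    unsafe { &*(bytes.as_ptr() as *const TWeights) }
}

fn main() {
    // Stand-in for the mmap: a Vec<f32> is guaranteed to be f32-aligned.
    let backing: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.25];
    // SAFETY: the pointer and length describe exactly the Vec's initialized bytes.
    let bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(
            backing.as_ptr() as *const u8,
            backing.len() * size_of::<f32>(),
        )
    };
    let weights = view_weights(bytes);
    println!("{:?} {:?}", weights.embedding, weights.bias);
}
```

Taking the buffer as a borrowed slice ties the reference's lifetime to it, which avoids the unconstrained lifetime that a bare pointer cast would produce.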