Michael-F-Bryan opened this issue 3 years ago
Some things that should be defined...

- A tensor (`hotg_rune_core::Tensor<T>`) may have the following element types: the primitive numeric types, plus strings (`Cow<'static, str>`)
- You add a `#[transform]` attribute to a `Transform` trait implementation to indicate that your proc block can transform certain data. It takes `inputs` and `outputs` arguments which specify which input and output types are supported. Omitting `inputs` or `outputs` will let the proc block infer what to use, and will be roughly equivalent to writing `#[transform(inputs = P[..], outputs = Q[..])]` for some input element type `P` and output element type `Q`
- Tensor types in these attributes are written as an element type plus an optional shape, for example (a few more attribute invocations are sketched after this list):
  - `f32[1, 2, 3]` - a 1x2x3-element tensor of `f32`s
  - `f32` - shorthand for `f32[1]`
  - `f32[_, 256, 256, 3]` - a Nx256x256x3 tensor of `f32`s, where `N` may take any non-negative value
  - `f32[..]` - a tensor with any arbitrary number of dimensions
  - `f32[2, ..]` - a tensor whose first dimension must be 2, but which can have zero or more additional dimensions
- You can add `#[derive(ProcBlock)]` above a type definition and it'll implement the `ProcBlock` trait for that type
- You add the `#[arguments]` attribute to your type's `impl` block; any method marked with the `#[argument]` attribute is treated as an argument and must have the signature `fn(&mut self, &'static str) -> Result<(), impl Display>`
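To make the shape syntax concrete, here are a few hypothetical `#[transform]` attributes exercising each form (the proc block names are invented for illustration, and none of this is meant to compile today):

```rust
// Accepts exactly a 1x2x3 tensor and produces a single scalar (f32 == f32[1]).
#[transform(inputs = f32[1, 2, 3], outputs = f32)]
impl Transform<Tensor<f32>> for Sum { ... }

// An image-style input whose leading (batch) dimension is left open.
#[transform(inputs = u8[_, 256, 256, 3], outputs = f32[_, 1000])]
impl Transform<Tensor<u8>> for ImageClassifier { ... }

// Omitting inputs/outputs is roughly #[transform(inputs = f32[..], outputs = f32[..])],
// i.e. any number of dimensions is accepted.
#[transform]
impl Transform<Tensor<f32>> for Identity { ... }
```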
An example of how you might implement a tokenizer proc block under this scheme:
```rust
use core::convert::Infallible;

/// A BERT tokenizer.
#[derive(ProcBlock)]
struct Tokenizer {
    word_list: Vec<&'static str>,
}

#[arguments]
impl Tokenizer {
    #[argument]
    pub fn set_word_list(&mut self, value: &'static str) -> Result<(), Infallible> {
        self.word_list = value
            .lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect();
        Ok(())
    }

    fn tokenize(&self, sentence: &str) -> (Tensor<i32>, Tensor<i32>, Tensor<i32>) { ... }
}

#[transform(inputs = utf8, outputs = (i32[_], i32[_], i32[_]))]
impl Transform<Tensor<Cow<'static, str>>> for Tokenizer {
    type Output = (Tensor<i32>, Tensor<i32>, Tensor<i32>);

    fn transform(&mut self, input: Tensor<Cow<'static, str>>) -> Self::Output {
        assert_eq!(
            input.dimensions(),
            &[1],
            "This proc block only accepts a tensor containing a single string",
        );
        let sentence = input.get(&[0]).unwrap();

        self.tokenize(sentence);
        ...
    }
}

#[transform(inputs = u8[_])]
impl Transform<Tensor<u8>> for Tokenizer {
    type Output = (Tensor<i32>, Tensor<i32>, Tensor<i32>);

    fn transform(&mut self, input: Tensor<u8>) -> Self::Output {
        assert_eq!(input.dimensions().len(), 1, "This proc block only accepts 1D tensors");
        let sentence: &[u8] = input.elements();
        let sentence: &str = core::str::from_utf8(sentence).expect("The input was invalid UTF8");

        self.tokenize(sentence);
        ...
    }
}
```
In terms of documentation and examples, I think most of this would be done in doc-comments on the corresponding procedural macros. That way we can include loads of examples which cargo test
will automatically pick up and check for us.
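For example, the docs for one of the macros might embed a self-checking snippet along these lines (purely illustrative; the real examples would exercise the actual macros):

```rust
/// Normalize pixel values to the range `[0, 1]`.
///
/// # Examples
///
/// Fenced code in this doc-comment is compiled and run by `cargo test`:
///
/// ```
/// let pixels = [0.0_f32, 127.5, 255.0];
/// let normalized: Vec<f32> = pixels.iter().map(|p| p / 255.0).collect();
/// assert_eq!(normalized, [0.0, 0.5, 1.0]);
/// ```
pub struct Normalize;
```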
After playing around with Forge a bit more, I think the extra type safety we get from `Transform` being generic over its inputs is actually making our lives harder and hurting the end-user experience. It's great to get errors from the compiler when you are writing Rust, but a typical Forge user is several steps removed from the Rust source code being compiled. Instead, we should aim for a single all-encompassing interface which takes a list of tensors as inputs and returns a list of tensors. The tensors should also do type checking internally instead of using a generic type parameter.

Among other things, this will let us remove the arbitrary restriction on the maximum number of inputs/outputs, because they can be stored in a slice (e.g. `&[Tensor]`) instead of needing to go through tuples and a trait that gets implemented for each arity. This arbitrary limit (previously 13) actually bit @Ge-te when he was implementing Tractable's rune.
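A minimal sketch of that direction, where every name is hypothetical rather than a settled API:

```rust
/// A dynamically typed tensor. Element type checks happen at runtime
/// inside the proc block instead of via a generic type parameter.
/// (Hypothetical type; a real tensor would carry more information.)
pub enum Tensor {
    U8 { dimensions: Vec<usize>, elements: Vec<u8> },
    I32 { dimensions: Vec<usize>, elements: Vec<i32> },
    F32 { dimensions: Vec<usize>, elements: Vec<f32> },
    Utf8 { dimensions: Vec<usize>, elements: Vec<String> },
}

#[derive(Debug)]
pub enum TransformError {
    /// The caller passed a tensor with the wrong element type or shape.
    InvalidInput { index: usize, reason: String },
    Other(String),
}

/// The single all-encompassing interface: a slice of tensors in, a list
/// of tensors out, with no arity limit baked into the trait.
pub trait ProcBlock {
    fn transform(&mut self, inputs: &[Tensor]) -> Result<Vec<Tensor>, TransformError>;
}
```

Because the inputs are just a slice, a proc block with 20 inputs is no different from one with 2, and type or shape mismatches become ordinary runtime errors that Forge can surface to the user instead of cryptic rustc failures.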
Currently, proc blocks are mostly implemented on an ad-hoc basis with a lot of work left up to the Rust compiler to catch bugs.
We want to take advantage of Rust's procedural macros to enforce a consistent structure and generate metadata, then use WebAssembly custom sections to give external programs access to that metadata without needing to execute the unknown proc block. This is currently done using `#[derive(ProcBlock)]`.

As it is, while we've been generating this metadata for a while, it isn't actually used by anything. That means the way a proc block is implemented and the way it is used can diverge, causing cryptic compilation errors because `rune build` blindly generates invalid code.

There are roughly 3 pieces of information external programs need to know about a proc block:
- The arguments it accepts (e.g. `width`), each of which can be set to a string (see #237)
- Which input tensors it accepts
- Which output tensors it produces

Later on, we may also include things a proc block requires from the runtime in order to run (e.g. `extern "C"` functions for hardware-accelerated operations).
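As a sketch of the custom-section half of this, the derive could expand to something like the following (the section name and payload format are placeholders, not a settled design):

```rust
/// Emitted by the (hypothetical) derive expansion. On wasm32 targets,
/// `#[link_section]` places these bytes in a custom section of the
/// compiled module, so external tools can read the metadata with an
/// ordinary wasm parser without ever instantiating the proc block.
#[link_section = ".rune_proc_block"]
#[used]
pub static PROC_BLOCK_METADATA: [u8; 20] = *b"{\"name\":\"tokenizer\"}";
```

`rune build` could then parse that section and validate a pipeline against the proc block's declared arguments, inputs, and outputs before it generates any glue code.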