FluxML / Torch.jl

Sensible extensions for exposing torch in Julia.

Updated C wrapper wrt. Torch v1.10 #61

Closed: stemann closed this 1 month ago

stemann commented 1 year ago

Updates the C wrapper based on ocaml-torch 0.14, which matches Torch v1.10 (the version in the current JLL build).

Contributes to #54; follow-up to #56.

Notable included changes:

The last two changes could be moved to a separate PR (to reduce the number of changes in this PR).

To-do:

zsz00 commented 11 months ago

What's the status now?

stemann commented 11 months ago

It's been idle for a while, but the status is summarised by these comments:

stemann commented 10 months ago

See #54 for status.

DhairyaLGandhi commented 4 months ago

Perhaps it would make sense to start merging some of the excellent changes here?

stemann commented 4 months ago

Yes! :-)

Please give torch_api.{cpp,h} a thorough review. I have made some changes to make the API more consistent, e.g. so that functions always return an int status code.

Edit: I'll also go over it myself and write a recap of the changes.

stemann commented 4 months ago

This is the current main diff, covering the hand-written part (torch_api.{cpp,h}): https://github.com/FluxML/Torch.jl/compare/7828132d..ece96546

stemann commented 4 months ago

The overall aim was to update for Torch v1.10.2, but also to make it easier to apply a diff of changes for subsequent version updates.

stemann commented 4 months ago

Recap:

General

Modified function definitions

```c
int at_float_vec(double *values, int value_len, int type);
int at_int_vec(int64_t *values, int value_len, int type);
int at_grad_set_enabled(int);
int at_int64_value_at_indexes(double *i, tensor, int *indexes, int indexes_len);
tensor at_load(char *filename);
int ato_adam(optimizer *, double learning_rate,
             double beta1,
             double beta2,
             double weight_decay);
int atm_load(char *, module *);
```

Added function definitions

```c
int at_is_sparse(int *, tensor);
int at_device(int *, tensor);
int at_stride(tensor, int *);
int at_autocast_clear_cache();
int at_autocast_decrement_nesting(int *);
int at_autocast_increment_nesting(int *);
int at_autocast_is_enabled(int *);
int at_autocast_set_enabled(int *, int b);
int at_to_string(char **, tensor, int line_size);
int at_get_num_threads(int *);
int at_set_num_threads(int n_threads);
int ati_none(ivalue *);
int ati_bool(ivalue *, int);
int ati_string(ivalue *, char *);
int ati_tuple(ivalue *, ivalue *, int);
int ati_generic_list(ivalue *, ivalue *, int);
int ati_generic_dict(ivalue *, ivalue *, int);
int ati_int_list(ivalue *, int64_t *, int);
int ati_double_list(ivalue *, double *, int);
int ati_bool_list(ivalue *, char *, int);
int ati_string_list(ivalue *, char **, int);
int ati_tensor_list(ivalue *, tensor *, int);
int ati_to_string(char **, ivalue);
int ati_to_bool(int *, ivalue);
int ati_length(int *, ivalue);
int ati_to_generic_list(ivalue, ivalue *, int);
int ati_to_generic_dict(ivalue, ivalue *, int);
int ati_to_int_list(ivalue, int64_t *, int);
int ati_to_double_list(ivalue, double *, int);
int ati_to_bool_list(ivalue, char *, int);
int ati_to_tensor_list(ivalue, tensor *, int);
```

stemann commented 2 months ago

@DhairyaLGandhi Do you know if anyone is available to review these changes?

I'm at JuliaCon, FYI.