fff-rs / juice

The Hacker's Machine Learning Engine

Coaster convolution API cleanup #178

Closed hweom closed 1 year ago

hweom commented 1 year ago

What does this PR accomplish?

Refactor the Coaster convolution API and move the workspace into ConvolutionConfig (which is renamed to ConvolutionContext).

Previously, the workspace was opaque to users of the API: they had to construct it and pass it to the API functions but had no other use for it. The workspace is also an internal detail of the CUDA backend; the native implementation, for example, doesn't use it (although the native implementation is currently incomplete).
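To illustrate the shape of the change, here is a minimal, self-contained sketch. It is not the Coaster API: apart from the name ConvolutionContext, every function, field, and signature below is hypothetical, and a plain `Vec<f32>` stands in for the backend's scratch memory. It only contrasts a caller-managed workspace with a context-owned one, using a toy 1D operation.

```rust
// Hypothetical sketch only; these are NOT Coaster's real types or signatures.

/// Shared kernel: 1D "valid" cross-correlation (no filter flip, as is common
/// in ML frameworks). `scratch` plays the role of the backend workspace.
fn correlate(filter: &[f32], scratch: &mut [f32], input: &[f32]) -> Vec<f32> {
    assert!(!filter.is_empty(), "filter must not be empty");
    assert!(scratch.len() >= filter.len(), "workspace too small");
    input
        .windows(filter.len())
        .map(|window| {
            // Write the element-wise products into the scratch buffer...
            for ((slot, x), f) in scratch.iter_mut().zip(window).zip(filter) {
                *slot = x * f;
            }
            // ...then reduce them to one output value.
            scratch[..filter.len()].iter().sum::<f32>()
        })
        .collect()
}

/// "Before" shape: the caller allocates the workspace and threads it through
/// every call, although it never reads it.
pub fn convolve_with_external_workspace(
    filter: &[f32],
    workspace: &mut [f32],
    input: &[f32],
) -> Vec<f32> {
    correlate(filter, workspace, input)
}

/// "After" shape: the context owns the workspace, so it disappears from the
/// public signature. A backend that needs no scratch memory could leave it empty.
pub struct ConvolutionContext {
    filter: Vec<f32>,
    workspace: Vec<f32>,
}

impl ConvolutionContext {
    pub fn new(filter: Vec<f32>) -> Self {
        assert!(!filter.is_empty(), "filter must not be empty");
        // The context sizes its own scratch memory; the caller never sees it.
        let workspace = vec![0.0; filter.len()];
        Self { filter, workspace }
    }

    pub fn convolve(&mut self, input: &[f32]) -> Vec<f32> {
        correlate(&self.filter, &mut self.workspace, input)
    }
}
```

With the second shape, a caller only writes `ConvolutionContext::new(filter)` followed by `ctx.convolve(&input)`, and backend-specific scratch memory stays an implementation detail.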

This is in preparation for reimplementing the convolution layer in the new arch.

Changes proposed by this PR:

Notes to reviewer:

Tests are not yet changed; I'll update them if you're OK with the overall idea.

Verified that the MNIST conv example runs as before and converges to 95% accuracy.

📜 Checklist