Closed: tingwl0122 closed this PR 6 months ago.
Will add a unit test Python script soon.
Hi @jiaqima @TheaperDeng, I have completed a wrapper function (`get_projection`) that automatically chooses the type of projector. It directly produces the projected matrix for the user based on their settings.
Please take a look if you have time.
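The selection a wrapper like `get_projection` automates can be sketched as a small dispatch rule. This is a hypothetical illustration, not the PR's code: the function name `choose_projector_type`, its three inputs, and the fallback to `BasicProjector` when `fast_jl` is missing are all assumptions.

```python
def choose_projector_type(device: str, fast_jl_available: bool, num_chunks: int) -> str:
    """Hypothetical rule for picking a projector class name.

    Assumptions (not taken from the PR): CPU devices, or CUDA without fast_jl,
    use BasicProjector; on CUDA, the parameter-chunk count decides between
    CudaProjector and ChunkedCudaProjector.
    """
    if not device.startswith("cuda") or not fast_jl_available:
        # No usable CUDA kernel, so fall back to the pure-torch projector.
        return "BasicProjector"
    if num_chunks > 1:
        # Parameters do not fit in a single chunk, so project chunk by chunk.
        return "ChunkedCudaProjector"
    return "CudaProjector"
```

Keeping the rule in one function means the user only states device and sizes, and the class choice stays in one place.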
Some small utility functions from TRAK could potentially be merged into `dattri/func/util.py`.
I skipped the ruff rule PLC0415, which requires `import XYZ` statements to be at the top of the file, because `random_projection.py` optionally imports `fast_jl` (`import fast_jl`) only when the user has a CUDA device.
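A minimal sketch of the optional-import pattern that triggers PLC0415, assuming CUDA availability is passed in as a flag (in practice it would likely come from `torch.cuda.is_available()`). The helper name `try_import_fast_jl` is hypothetical; the inline `# noqa` is shown as a per-line alternative to disabling the rule project-wide.

```python
from types import ModuleType
from typing import Optional


def try_import_fast_jl(cuda_available: bool) -> Optional[ModuleType]:
    """Return the fast_jl module if CUDA is usable and it is installed, else None."""
    if not cuda_available:
        # Skip the import entirely on CPU-only machines.
        return None
    try:
        # Non-top-level import on purpose; this is the line PLC0415 would flag.
        import fast_jl  # noqa: PLC0415
    except ImportError:
        # fast_jl is not installed, so callers should fall back to a CPU projector.
        return None
    return fast_jl
```

Only `ImportError` is caught here; if `fast_jl` could fail in other ways at import time on a given setup, the caller would need to handle those separately.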
Do we still need those chunk-size and vectorize functions here, given the `flatten_params` function implemented by @TheaperDeng?
Oh yes, `vectorize` is no longer needed. The chunk-size function, however, is the criterion for switching between `CudaProjector` and `ChunkedCudaProjector`, and that is independent of whether the input is flattened.
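To illustrate why chunking is independent of flattening, here is a hypothetical greedy chunker over per-parameter sizes. It is not TRAK's `get_parameter_chunk_sizes`; the name `chunk_parameter_sizes` and the `max_elements` budget are assumptions. More than one resulting chunk would indicate the chunked projector.

```python
from typing import List


def chunk_parameter_sizes(param_sizes: List[int], batch_size: int, max_elements: int) -> List[int]:
    """Greedily group consecutive parameter sizes into chunks.

    Each chunk's total size times batch_size should stay within max_elements;
    a single parameter larger than that budget still forms its own chunk.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    limit = max_elements // batch_size  # per-sample element budget per chunk
    chunks: List[int] = []
    current = 0
    for size in param_sizes:
        if current and current + size > limit:
            # Adding this parameter would overflow the budget: close the chunk.
            chunks.append(current)
            current = 0
        current += size
    if current:
        chunks.append(current)
    return chunks
```

For example, sizes `[3, 4, 5]` with batch size 1 and a budget of 8 yield chunks `[7, 5]`, so a chunked projector would be needed.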
@tingwl0122 could you make the description of this PR a bit more informative before merging it? Maybe following the PR template.
Sure, will do.
Closing and reopening to re-run pytest.
Hi @TheaperDeng,
I will fix the `_vectorize` dependency issue in a follow-up PR, so this one can be merged.
The current usage looks like this:
```python
import torch

# Inside a unit test: mimic the per-parameter gradients of a torch model
small_gradient = {}
for name, p in self.small_model.named_parameters():
    small_gradient[name] = torch.rand(test_batch_size, p.numel())

# On CPU, this is expected to select BasicProjector
project_func = random_project(
    small_gradient,
    test_batch_size,
    self.proj_dim,
    self.proj_max_batch_size,
    device="cpu",
    proj_seed=0,
    use_half_precision=True,
)
projected_gradient = project_func(small_gradient)
```
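Conceptually, the CPU path concatenates each sample's per-parameter gradients into one feature vector and multiplies it by a seeded random matrix. The stdlib-only sketch below assumes Rademacher (±1) entries scaled by 1/sqrt(proj_dim), which preserves squared norms in expectation; the helper names are hypothetical and this is not the projector classes' actual implementation.

```python
import math
import random
from typing import Dict, List


def flatten_per_sample(grads: Dict[str, List[List[float]]]) -> List[List[float]]:
    """Concatenate per-parameter gradient rows into one feature vector per sample."""
    names = list(grads)
    batch = len(grads[names[0]])
    return [[x for n in names for x in grads[n][i]] for i in range(batch)]


def rademacher_matrix(feature_dim: int, proj_dim: int, seed: int) -> List[List[float]]:
    """Seeded feature_dim x proj_dim matrix with entries +/- 1/sqrt(proj_dim)."""
    rng = random.Random(seed)  # same seed gives the same matrix
    scale = 1.0 / math.sqrt(proj_dim)
    return [[scale if rng.random() < 0.5 else -scale for _ in range(proj_dim)]
            for _ in range(feature_dim)]


def project(features: List[List[float]], matrix: List[List[float]]) -> List[List[float]]:
    """Compute features @ matrix, giving one proj_dim-sized row per sample."""
    proj_dim = len(matrix[0])
    return [[sum(row[k] * matrix[k][j] for k in range(len(row))) for j in range(proj_dim)]
            for row in features]
```

Fixing the seed matters because every batch of gradients has to be projected with the same matrix for the projected features to be comparable.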
Description
1. Motivation and Context
To port the TRAK random projectors into this repository and add a wrapper function that lets users directly get the suitable projector type.
2. Summary of the change
This PR adds three projector classes (from TRAK and the `fast_jl` library), a wrapper function `get_projection` that automatically chooses the right type of projector for users, and some smaller helper functions from TRAK.

Note 1: The helper functions `parameters_to_vector` and `get_num_params` will be removed in the next PR. `get_parameter_chunk_sizes` could possibly be merged into `dattri.func.utils`.

Note 2: The dependency on `vectorize` will be replaced by `flatten_params` in the next PR.

Note 3: The notation `grad_dim` and related notation will be replaced by `feature_dim` in the next PR.

3. What tests have been added/updated for the change?