helmholtz-analytics / heat

Distributed tensors and Machine Learning framework with GPU and MPI acceleration in Python
https://heat.readthedocs.io/
MIT License

[Feature-384 Sparse] `to_sparse()` method for `DNDarray` class #1054

Closed: Mystic-Slice closed this issue 1 year ago

Mystic-Slice commented 1 year ago

Related: Pending work from https://github.com/helmholtz-analytics/heat/pull/1028

Feature functionality: Implement a to_sparse() method for the DNDarray class to create a DCSR_matrix.

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, I am interested in working on this issue.

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, in this issue we basically have to convert a DNDarray to a DCSR_matrix. Presently the array we have is a tensor, so now, even if a DNDarray is passed, we should be able to convert it into a DCSR_matrix. Is this the correct understanding of the issue, or have I misunderstood anything?

Mystic-Slice commented 1 year ago

@Ishaan-Chandak Thanks for your interest! Yes, we have torch.Tensors inside a DNDarray. I think this method can simply call ht.sparse.sparse_csr_matrix on the local tensor of each process (dndarray.larray) with the appropriate is_split parameter value.

Also, make sure that the DNDarray is balanced before doing this, since the sparse module does not yet handle imbalance between processes.
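To make the idea above concrete, here is a minimal pure-Python sketch of the two steps: splitting rows evenly across processes, then converting each local chunk to CSR on its own. It does not use Heat or MPI. The helper names `balanced_row_counts` and `dense_to_csr` are hypothetical, and the rule that leftover rows go to the first ranks is an assumption made for illustration, not a statement about Heat's internals. In Heat itself the per-process step would presumably be the `ht.sparse.sparse_csr_matrix` call on `dndarray.larray` suggested above.

```python
# Pure-Python sketch (no Heat, no MPI): simulates the row chunks of a split=0 array.


def balanced_row_counts(n_rows, n_procs):
    """Rows per process for a balanced row split.

    Assumption for illustration: leftover rows go to the first ranks.
    """
    base, rem = divmod(n_rows, n_procs)
    return [base + (1 if rank < rem else 0) for rank in range(n_procs)]


def dense_to_csr(rows):
    """Convert a list of dense rows into CSR components (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in rows:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)
                indices.append(col)
        # indptr[i + 1] marks where row i ends inside `data`
        indptr.append(len(data))
    return data, indices, indptr


arr = [
    [0, 0, 1, 0, 2],
    [0, 0, 0, 0, 0],
    [0, 3, 0, 0, 0],
    [4, 0, 0, 5, 0],
    [0, 0, 0, 0, 6],
]

# Step 1: balanced distribution of rows over two simulated processes.
counts = balanced_row_counts(len(arr), n_procs=2)

# Step 2: every "process" converts only its own chunk of rows.
local_csr = []
start = 0
for count in counts:
    local_csr.append(dense_to_csr(arr[start:start + count]))
    start += count

# Stitching the local pieces together reproduces the global CSR layout:
# each local indptr is shifted by the number of non-zeros seen so far.
global_data, global_indices, global_indptr = [], [], [0]
for data, indices, indptr in local_csr:
    offset = len(global_data)
    global_data += data
    global_indices += indices
    global_indptr += [offset + p for p in indptr[1:]]

assert (global_data, global_indices, global_indptr) == dense_to_csr(arr)
```

The final check shows why the local conversion is sufficient: the per-process CSR pieces differ from the global one only by an offset in the row pointers, provided each process knows how many non-zeros precede its chunk.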

When you make a PR for this, try filling out the form provided so that the PR holds complete information about your change/addition.

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, thanks for the answer. I have one more question: in which file should I add this function?

Mystic-Slice commented 1 year ago

It would make sense to add this in heat/core/manipulations.py

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, just an update: I am reading the Heat docs for a better understanding, and I have managed to write some code. I am not yet sure whether it is correct, so I will try running it with some modifications and then post the final results here.

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, can you help me find some test cases so I can check whether the program I have written is correct?

Mystic-Slice commented 1 year ago

Hey @Ishaan-Chandak! Sorry for the late reply. You can test it using something like:

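import heat as ht

arr = [[0, 0, 1, 0, 2],
       [0, 0, 0, 0, 0],
       [0, 3, 0, 0, 0],
       [4, 0, 0, 5, 0],
       [0, 0, 0, 0, 6]]

A = ht.array(arr, split=0)  # DNDarray

B = ht.array(arr, split=0).to_sparse()  # DCSR_matrix

# Compare the densified local chunks and reduce to a single boolean;
# asserting on an element-wise tensor comparison is ambiguous.
assert (A.larray.to_sparse_csr().to_dense() == B.larray.to_dense()).all()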

I have, of course, not tested this code, but you should be able to verify the to_sparse() method this way.

Ishaan-Chandak commented 1 year ago

Hi @Mystic-Slice, I am not sure why this error is popping up. Can you please help me resolve it?

[Image: Screenshot from 2023-03-06 17-58-34]

Mystic-Slice commented 1 year ago

I am not completely sure, but it might be because you are inside the heat/heat/sparse/ folder. Try running this from the root folder of the project.

samadpls commented 1 year ago

Hey @Mystic-Slice, could you please assign this to me?

github-actions[bot] commented 1 year ago

Branch 1054-_Feature-384_Sparse_to_sparse_method_for_DNDarray_class created!

Mystic-Slice commented 1 year ago

Removing @Ishaan-Chandak due to inactivity.

@samadpls Please create a PR when you are done with this work. And feel free to ask questions if you have any.

samadpls commented 1 year ago

Hello @Mystic-Slice

I have submitted a Pull Request. Could you please review it so I can start writing test cases? Also, I am trying to run the code locally, but I'm facing a circular import issue. This is how I am testing it:

import heat as ht

array = [[0, 0, 1, 0, 2],
         [0, 0, 0, 0, 0],
         [0, 3, 0, 0, 0],
         [4, 0, 0, 5, 0],
         [0, 0, 0, 0, 6]]

A = ht.array(array, split=0)
B = A.to_sparse()
# Compare the densified local chunks and reduce to a single boolean
assert (A.larray.to_sparse_csr().to_dense() == B.larray.to_dense()).all()

I have already checked my imports and there doesn't seem to be any direct circular imports in the code I've provided. Could you please help me troubleshoot this issue?