Closed: nissy-dev closed this issue 4 years ago
I tried some tutorials about JAX and got used to it. I categorized them into required ones and optional ones. My notebook is here
`jit`, `grad` and `vmap` (`vmap` is really interesting).
The Gradients section at the beginning.
The "Write the log-joint function for the model" section.
I will get used to JAX with GNN: the "3. JAX implementation" section.
Today, I think there are no blockers because today's task is just absorbing knowledge.
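As a quick reminder of what these transforms do, here is a minimal sketch (my own toy example, not from the notebook) combining `grad`, `vmap`, and `jit` on a squared-error loss:

```python
import jax
import jax.numpy as jnp

# A scalar loss for a single example: simple squared error.
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# grad: d(loss)/dw for one example.
g = jax.grad(loss)(w, x, 5.0)

# vmap: the same per-example gradient, mapped over a batch
# without writing a Python loop.
xs = jnp.stack([x, x])
ys = jnp.array([5.0, 5.0])
batched_grad = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

# jit: compile the batched computation with XLA.
fast_batched_grad = jax.jit(batched_grad)
gs = fast_batched_grad(w, xs, ys)
```

The appeal of `vmap` is exactly this composition: the per-example function stays scalar and readable, and batching is added from the outside.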
I tried a GCN tutorial about JAX and read the MNIST code in dm-haiku. Haiku is a simple and lightweight NN library, and I will try using it with JAXChem. My notebook is here
I will implement a graph property prediction model by refactoring yesterday's tutorial GCN. (The tutorial builds a node classification model on the Cora dataset.)
I implemented a graph property prediction model by refactoring yesterday's tutorial GCN. I made the draft PR: https://github.com/deepchem/jaxchem/pull/3
Mainly, I will implement training code with Tox21.
Mainly, I implemented GCN code with Tox21.
Mainly, I will implement GCN code with Tox21.
PS. I finished the GCN implementation with Tox21 and rethought my plan. Please confirm this issue: https://github.com/deepchem/jaxchem/issues/1
Today, I think there are no blockers.
I implemented GCN code with Tox21.
Mainly, I set up the test environment and looked into pytest. Before refactoring with Haiku, I will just write shape tests. After refactoring with Haiku and officially deciding to use Haiku in jaxchem, I will write additional tests.
Today, I think there are no blockers.
I set up some environments and implemented normalization of the adjacency matrix for GCN
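For reference, the standard symmetric normalization from the Kipf & Welling GCN paper can be sketched as below. Whether jaxchem adds self-loops at this step is my assumption; the function name is illustrative.

```python
import jax.numpy as jnp

def normalize_adj(adj):
    """Symmetric normalization D^{-1/2} (A + I) D^{-1/2} (Kipf & Welling GCN).

    Adding the identity inserts self-loops so each node keeps its own features.
    """
    adj = adj + jnp.eye(adj.shape[0])          # add self-loops
    deg = jnp.sum(adj, axis=1)                 # node degrees
    d_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)
    return d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

# Two connected nodes: A = [[0, 1], [1, 0]]
a_hat = normalize_adj(jnp.array([[0.0, 1.0], [1.0, 0.0]]))
```

The `jnp.where` guard keeps isolated nodes (degree zero after masking) from producing infinities.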
Mainly, I will write a shape test for GCN.
Today, I think there are no blockers.
I wrote a shape test for GCN.
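A shape test in the pytest style might look like the sketch below; `gcn_layer` here is a hypothetical stand-in, since the real jaxchem model signature may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical minimal GCN layer, standing in for the real jaxchem model.
def gcn_layer(w, adj, x):
    return jax.nn.relu(adj @ x @ w)

def test_gcn_layer_shape():
    n_nodes, in_feats, out_feats = 5, 8, 16
    key = jax.random.PRNGKey(0)
    w = jax.random.normal(key, (in_feats, out_feats))
    adj = jnp.ones((n_nodes, n_nodes))
    x = jnp.ones((n_nodes, in_feats))
    out = gcn_layer(w, adj, x)
    # The layer must map (n_nodes, in_feats) -> (n_nodes, out_feats).
    assert out.shape == (n_nodes, out_feats)

test_gcn_layer_shape()
```

Shape tests like this are cheap but catch the most common refactoring mistakes (transposed weights, wrong aggregation axis) before any numerical testing.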
Mainly, I will refactor the GCN model using Haiku. Today, I learned that jax-md uses Haiku.
I think there are no blockers. But this implementation is a little hard; maybe I can't finish it in one day.
I took the day off.
Mainly, I will refactor the GCN model using Haiku.
I think there are no blockers. But this implementation is a little hard; maybe I can't finish it in one day.
I finished refactoring the GCN model using Haiku.
Mainly, I will implement the sparse pattern GCN model. The previous model's input is an adjacency matrix, but this model's input is an adjacency list. An adjacency list is more memory efficient than an adjacency matrix. I posted the next detailed plan.
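To illustrate the memory point, neighbor aggregation over an adjacency list can be sketched with `jax.ops.segment_sum`. The `(src, dst)` edge layout here is my assumption, not necessarily the actual jaxchem format.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# A 3-node graph stored as an adjacency *list*: one (src, dst) pair per edge.
# This costs O(num_edges) memory instead of the O(num_nodes**2) dense matrix.
edges = jnp.array([[0, 1], [1, 0], [1, 2], [2, 1]])  # undirected, both directions
x = jnp.array([[1.0], [10.0], [100.0]])              # node features, shape (3, 1)

src, dst = edges[:, 0], edges[:, 1]
# Gather each edge's source features, then sum them into its destination node.
messages = x[src]
aggregated = segment_sum(messages, dst, num_segments=3)
```

For molecular graphs, which are very sparse (a few bonds per atom), the edge-list form is a large saving over padded dense matrices.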
I think there are no blockers.
I implemented the sparse pattern GCN model.
I implemented the sparse pattern GCN model and an example. I faced a performance issue and I'm struggling to resolve it.
I think there are no blockers.
I couldn't work on jaxchem on 6/23 and 6/24 because of my research tasks. I started resolving the performance issue yesterday.
I'm still struggling to resolve the performance issue.
The reason for the performance issue is related to google/jax#2242. SparseGCN uses `jax.ops.index_add`, but a large Python for loop leads to a serious performance issue when using `jax.ops.index_add` (the training time of this example is 24 times that of the PadGCN example).
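For context, `jax.ops.index_add` was the JAX API at the time (mid-2020); current JAX spells the same scatter-add as `x.at[idx].add(v)`. The slowdown typically comes from issuing one small op per edge inside a Python loop, whereas a single batched scatter-add over all edges avoids it. A toy sketch of the two patterns:

```python
import jax.numpy as jnp

dst = jnp.array([1, 0, 2, 1])                    # destination node per edge
msgs = jnp.array([[1.0], [2.0], [3.0], [4.0]])   # one message per edge
out = jnp.zeros((3, 1))

# Slow pattern: one scatter op per edge, issued from a Python for loop
# (the shape of the problem described above with jax.ops.index_add).
slow = out
for i in range(dst.shape[0]):
    slow = slow.at[dst[i]].add(msgs[i])

# Fast pattern: a single batched scatter-add over every edge at once.
# Duplicate indices in `dst` accumulate, which is exactly what we want.
fast = out.at[dst].add(msgs)
```

Both produce identical results; the difference is that the batched form is one XLA op regardless of edge count.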
I tried to rewrite the training loop code using `lax.scan` or `lax.fori_loop`. However, generators/iterators don't work in `lax.scan` or `lax.fori_loop`, so it is taking much more time than expected. (I filed an issue in the jax repo: https://github.com/google/jax/issues/3567)
If I'm not able to use a generator/iterator, maybe I have to write some code that converts DiskDataset to the original Dataset. The plan will be delayed for a few days. The following plan is the worst case. I want to post the blog by the end of the first evaluation term (7/4).
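Since a generator cannot cross the `lax.scan` tracing boundary, one common workaround is to materialize an epoch of batches into stacked arrays first and scan over those. A sketch with a hypothetical SGD step (none of these names are from jaxchem):

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_epoch(w, xs, ys):
    # xs, ys hold a whole epoch of batches stacked on a leading axis:
    # a generator cannot be traced, but a concrete array can be scanned over.
    def step(w, batch):
        x, y = batch
        g = jax.grad(loss)(w, x, y)
        return w - 0.1 * g, loss(w, x, y)   # (new carry, per-step output)
    return lax.scan(step, w, (xs, ys))

w = jnp.zeros((2,))
xs = jnp.ones((5, 4, 2))   # 5 batches of 4 examples with 2 features
ys = jnp.ones((5, 4))
w_final, losses = train_epoch(w, xs, ys)
```

The trade-off is memory: the epoch's batches must fit in arrays up front, which is why a disk-backed dataset like DiskDataset does not slot in directly.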
I think there are no blockers.
I was working on the performance issue, but I couldn't resolve it.
Today, I will write up the summary and issues and rethink the plan.
This week, I have decided to focus on writing summary and issue details for the GSoC evaluation.
I will also make a Colab notebook example and post blogs on the DeepChem Forum.
Updated Plan
I think there are no blockers.
I have spent four weeks at DeepChem as a GSoC student, and the 1st evaluation has come! I want to explain what I did in these four weeks.
As I mentioned in this roadmap (https://github.com/deepchem/jaxchem/issues/1), I tried to implement GCN models and make tutorials during the 1st evaluation period. The reason I chose this topic is that GCN (GNN) is the most popular deep learning method in the area of chemistry. I think this is a good starting point for JAXChem.
During the 1st evaluation period, I implemented the two patterns of GCN model.
If you want to confirm the details of the difference between the two models, please check the roadmap: https://github.com/deepchem/jaxchem/issues/1.
One of the challenging points of JAXChem is implementing the sparse pattern GCN model. Pad pattern modeling is easier, and the blog was published like this.
While implementing these models, I modified the roadmap in June (https://github.com/deepchem/jaxchem/issues/1#issuecomment-640702910) following some advice, and I prioritized making our code more readable and maintainable. I listed what I did.
`init_func` of archive models (archive/gcn): `init_func` disrupts our focus on the forward computation implementation.
I found the performance issue with the sparse pattern GCN model when making the Tox21 example.
The reason for the performance issue is related to google/jax#2242. The sparse pattern GCN model uses `jax.ops.index_add`, but a large Python for loop leads to a serious performance issue when using `jax.ops.index_add` (training time/epoch of the Tox21 example is 24 times that of the pad pattern GCN model).
In order to resolve this issue, I have to rewrite the training loop using `lax.scan` or `lax.fori_loop`. However, `lax.scan` and `lax.fori_loop` have some limitations, e.g. generators/iterators don't work (see: google/jax#3567), so it is difficult to rewrite. Now I'm struggling with this issue; please confirm the details in this issue: https://github.com/deepchem/jaxchem/issues/8
According to the roadmap, I was supposed to be working on implementing the CGCNN model. However, I will change this plan. In the next period, I will focus on resolving the performance issue and writing the documents. Please confirm the details below.
There are two reasons I will change the plan. First, the CGCNN model is similar to the sparse pattern model. Second, it seems that the crystal support of deepchem is currently at too early a stage and still needs many fixes. On the other hand, I will not change the plan for the final evaluation period (implementing the Molecular Attention Transformer).
My official project is JAXChem, but I have also committed to DeepChem core code. The reason is that JAXChem is one of the DeepChem projects. I think improving the DeepChem core code is really important so that many users learn about the JAXChem project and want to use it.
During the 1st evaluation period, I mainly cleaned up old documentation and build systems. I listed what I did in detail.
`mdtraj` or `openmm` couldn't be imported in colab.
Whether the `conda-forge` or `deepchem` channel is better for the future: conda-forge/deepchem
Go to #3
Daily standup in June
I use this template: https://www.range.co/blog/complete-guide-daily-standup-meeting-agenda
6/2
Yesterday
Today
I will get used to JAX.
The "What is JAX" section.
Blockers
Today, I think there are no blockers because today's task is just absorbing knowledge.