hjsuh94 / score_po

Score-Guided Planning

Architecture Study for Noise Conditioning #19

Open hjsuh94 opened 1 year ago

hjsuh94 commented 1 year ago

How do we actually train a Noise Conditioned Score Estimator?

The most straightforward / naïve way would be to append one more dimension to the input of the network (i.e., use an MLP(4, 3) instead of MLP(3, 3)).
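A minimal sketch of that naïve option (hypothetical sizes, assuming PyTorch and a 3-dimensional state):

```python
import torch
import torch.nn as nn

class NaiveConditionedScore(nn.Module):
    """Score network that takes sigma as an extra input dimension,
    i.e., MLP(4, 3) instead of MLP(3, 3)."""
    def __init__(self, dim_x=3, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_x + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, dim_x),
        )

    def forward(self, x, sigma):
        # x: (batch, dim_x); sigma: (batch, 1), appended as one more input.
        return self.net(torch.cat([x, sigma], dim=-1))
```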

It seems that the diffusion papers made some interesting architecture choices that we should consider adopting. In the original implementation (the ermongroup/ncsn repo), it's interesting to see that the value of sigma is never actually used as a direct input to the network.

Instead, if we look at this implementation of anneal_dsm_score_estimation, the labels are defined as integer variables here. So the network takes integer labels as input, as opposed to the actual values of sigma.
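Paraphrasing that function as a sketch (not a verbatim copy of the linked code; the weighting follows their anneal_power=2 default):

```python
import torch

def anneal_dsm_score_estimation(scorenet, samples, sigmas):
    # Draw a random integer noise-level label per sample; the network
    # only ever sees this integer, never the float value of sigma.
    labels = torch.randint(0, len(sigmas), (samples.shape[0],),
                           device=samples.device)
    used_sigmas = sigmas[labels].view(
        samples.shape[0], *([1] * (samples.dim() - 1)))
    perturbed = samples + torch.randn_like(samples) * used_sigmas
    # Denoising score matching target: -(x_tilde - x) / sigma^2.
    target = -(perturbed - samples) / used_sigmas ** 2
    scores = scorenet(perturbed, labels)  # conditioned on integer labels
    loss = 0.5 * ((scores - target) ** 2).flatten(1).sum(-1)
    return (loss * used_sigmas.flatten() ** 2).mean()  # sigma^2 weighting
```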

After getting this integer label, the data and the label are both passed into ConditionalResidualBlock, which normalizes the data conditioned on the label.

This normalization uses nn.Embedding to convert the integer label into a feature vector. That feature vector is then multiplied with the batch-normalized data, and this conditioning is repeated at multiple layers of the network.
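In that style, a conditional normalization layer looks roughly like this (a sketch modeled on the repo's conditional norm layers; the actual classes differ in details, e.g., instance-norm variants, but the embed-then-modulate pattern is the same):

```python
import torch
import torch.nn as nn

class ConditionalNorm2d(nn.Module):
    """Label-conditioned normalization: an nn.Embedding maps the integer
    noise-level label to per-channel scale/shift that modulate the
    normalized features."""
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        # One embedding row per noise level: first half gamma, second half beta.
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # gamma near 1
        self.embed.weight.data[:, num_features:].zero_()           # beta at 0

    def forward(self, x, labels):
        out = self.norm(x)
        gamma, beta = self.embed(labels).chunk(2, dim=1)
        gamma = gamma.view(-1, x.size(1), 1, 1)
        beta = beta.view(-1, x.size(1), 1, 1)
        return gamma * out + beta
```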

hongkai-dai commented 1 year ago

Hmm, I am confused after reading the functions in https://github.com/ermongroup/ncsn/blob/master/models/scorenet.py. As you said, sigma is never passed as an input. So they just use one score function (independent of sigma) to approximate the gradient, namely

$$\frac{1}{2}\,\mathbb{E}_{p(x)}\,\mathbb{E}_{\tilde{x} \sim \mathcal{N}(x, \sigma^2 I)}\left[\left\| s_\theta(\tilde{x}) + \frac{\tilde{x} - x}{\sigma^2} \right\|^2\right]$$

where the score function $s_\theta(\tilde{x})$ doesn't depend on the noise level $\sigma$? That is really strange, since I thought one of the enablers in the paper for training the score network was having a single network handle different levels of $\sigma$.
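For concreteness, that fixed-$\sigma$ objective would be something like the following (a sketch; here scorenet takes only $\tilde{x}$, with $\sigma$ a constant):

```python
import torch

def dsm_loss_single_sigma(scorenet, x, sigma):
    # 1/2 E || s_theta(x_tilde) + (x_tilde - x) / sigma^2 ||^2
    x_tilde = x + torch.randn_like(x) * sigma
    residual = scorenet(x_tilde) + (x_tilde - x) / sigma ** 2
    return 0.5 * (residual ** 2).flatten(1).sum(-1).mean()
```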

hjsuh94 commented 1 year ago

I think most of the classes in that particular module are made for training with a single sigma! The code in https://github.com/ermongroup/ncsn/blob/7f27f4a16471d20a0af3be8b8b4c2ec57c8a0bc1/models/cond_refinenet_dilated.py is the one that does the noise conditioning.

But I do wonder what the implication is for our modules: should we train with a single level of noise for now?