Open JWKKWJ123 opened 8 months ago
This is a novel and intriguing idea: training an Energy-Based Model (EBM) as a Generalised Additive Model (GAM) inside a large CNN or Transformer architecture. While classic EBMs are usually trained with end-to-end optimisation, they can be adapted to operate within a broader neural network architecture and trained incrementally (batch by batch).
Here is a code example of how you can achieve this:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Define a CNN architecture with an EBM output layer
class CNNWithEBM(nn.Module):
    def __init__(self):
        super(CNNWithEBM, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.ebm = nn.Linear(32 * 14 * 14, 1)  # Example linear EBM layer

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        energy = self.ebm(x)
        return energy

# Generate a synthetic dataset
def generate_data(batch_size=32):
    data = torch.randn(batch_size, 1, 28, 28)  # MNIST-like data
    labels = torch.randint(0, 2, (batch_size,))
    return data, labels
# Instantiate the model
model = CNNWithEBM()

# Define the loss function (energy-based loss)
criterion = nn.MSELoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
num_epochs = 10
batch_size = 32
num_batches = 100  # number of mini-batches per epoch (not defined in the original snippet)

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx in range(num_batches):
        # Generate a mini-batch of data
        data, labels = generate_data(batch_size)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        energy = model(data)

        # Compute loss
        loss = criterion(energy.squeeze(), labels.float())  # Energy-based loss

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        # Accumulate total loss
        total_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}, Avg. Loss: {total_loss / num_batches:.4f}")
I hope that this helps. Thank you
Dear Sunnycasmir, Thank you very much for your reply! More specifically, I want to use an EBM (Explainable Boosting Machine) as the output layer of a large CNN/Transformer. I considered wrapping the EBM as a custom torch layer, but that would make the EBM untrainable. So my question is: how can I train the EBM incrementally (batch by batch) as a custom torch layer? I don't think the example code addresses this.
Is it possible to see the code you are working on, so I can see how to contribute more?
Hi all, I have an update this week. I think the main difficulty is that deep-learning models and GAMs (including EBMs) have very different training strategies. GAMs need to read all the training data at once and update the weights of all shape functions on the residuals sequentially, while deep-learning models have to take the training data in mini-batches because of memory limits (I use a batch size of 4 now) and update the model step by step. I would like to use the EBM as the output block of a large end-to-end 3D CNN. The question then becomes: can the EBM be updated progressively, step by step (mini-batch by mini-batch), simultaneously with the CNN? I am trying to use merge_ebms() to train the EBM in batches, and it seems to work with a large batch (see the sketch after the code below).
This is the code where I put the EBM into a deep-learning model. For now I make the EBM untrainable inside the CNN, because I am going to train the EBM and the CNN alternately:
class EBM_layer(nn.Module):
    def __init__(self, **kwargs):
        super(EBM_layer, self).__init__(**kwargs)

    def forward(self, x, ebm):
        # The fitted EBM is a scikit-learn-style model, so the features must leave the autograd graph
        x = x.detach().cpu().numpy()
        output_pro_ebm = ebm.predict_proba(x)
        output_pro_ebm = output_pro_ebm[:, 1]  # probability of the positive class
        output_pro_ebm = torch.tensor(output_pro_ebm, requires_grad=True)
        output_pro_ebm = output_pro_ebm.unsqueeze(1)
        return output_pro_ebm
def forward(self, x, ebm):  # I now train the EBM and the CNN alternately, so a trained ebm is passed to the model in each epoch
    out_all = torch.empty(x.size(0), 0, device=x.device)  # accumulator for the concatenated features (must start empty)
    for i in range(0, N):
        out = self.cnnlist[i](x)
        out_all = torch.cat([out_all, out], 1)  # concatenation of the features extracted by the multiple CNNs
    out_pro = self.EBM_layer(out_all, ebm)
    return out_pro
Hi @JWKKWJ123 -- This kind of federated learning approach isn't something that we support out of the box. You can kind of hack it as you've discovered using merge_ebms, but the implementation isn't ideal. At some point we'll provide a better interface for building EBMs one boosting round at a time, and from batches.
Your other point about DNNs and EBMs (which are based on decision trees) is quite pertinent too. The training strategies are quite different, and it's not clear to me that bringing them together will result in an ideal union. An alternative approach that I might suggest would be to train the DNN as normal, then remove the last layer and train the EBM to replace it on the now-frozen DNN. Will this approach work for you?
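A rough sketch of this suggestion (my own illustration, not code from the thread), assuming a trained torch model whose final layer has already been removed so that it outputs penultimate-layer features; extract_features, headless_dnn, and train_loader are hypothetical names:

import numpy as np
import torch
from interpret.glassbox import ExplainableBoostingClassifier

@torch.no_grad()
def extract_features(headless_dnn, loader, device="cpu"):
    # Run the frozen DNN over the whole training set and collect its features
    headless_dnn.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(headless_dnn(x.to(device)).cpu().numpy())
        labels.append(y.numpy())
    return np.concatenate(feats), np.concatenate(labels)

# X_train, y_train = extract_features(headless_dnn, train_loader)
# ebm = ExplainableBoostingClassifier()
# ebm.fit(X_train, y_train)  # the EBM replaces the original fully connected head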
Dear Paul, Thank you very much for your reply! I'm glad I've made some progress now. I found it is possible to use merge_ebms() to train the EBM in batches alongside the DNN. However, I am now using a huge DNN, so I can only set the batch size to 4, and this training strategy does not work when the batch size is below 10. So after trial and error, I developed a new training strategy (figure below), which is to train the model alternately in two stages. I now use it in a case that takes both the whole image (global) and image patches (local) as input, where each pathway in the end-to-end model is a CNN:
This training strategy works (I accidentally added the accuracy twice in the epoch between the two stages). It can provide the contributions of the different pathways in a large composite DNN without sacrificing performance:
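If I read the two-stage idea correctly, the training loop would look roughly like the sketch below (my own illustration, not the author's code); train_cnn_one_epoch, temp_head, cnn_pathways, extract_features, and train_loader are hypothetical placeholders:

from interpret.glassbox import ExplainableBoostingClassifier

for epoch in range(num_epochs):
    # Stage 1: update the CNN pathways with a temporary differentiable head,
    # keeping the current EBM fixed
    train_cnn_one_epoch(cnn_pathways, temp_head, train_loader)

    # Stage 2: freeze the CNNs, extract the concatenated pathway features for
    # the whole training set, and refit the EBM on them
    X, y = extract_features(cnn_pathways, train_loader)
    ebm = ExplainableBoostingClassifier()
    ebm.fit(X, y)  # this refit EBM is passed back into the model's forward pass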
Hi all, I want to use an EBM as a GAM to replace the fully connected layer at the end of a large CNN/Transformer to get interpretable output. However, I need to train the EBM like a deep-learning model, with mini-batches of data as input. I would like to ask: is it possible to train the model step by step (batch by batch) instead of using the end-to-end fit() function? Or are there people already working on this? Yours Sincerely, Wenjie Kang