sustainability-lab / ASTRA

"AI for Sustainability" Toolkit for Research and Analysis

Diversity Acquisition Functions #12

Open jaiswalsuraj487 opened 10 months ago

jaiswalsuraj487 commented 10 months ago

Implemented the Furthest and Centroid acquisition functions in commit a6e59ff.

Files added or modified:

  1. astra/torch/al/acquisitions/furthest.py: implementation of the Furthest acquisition function.
  2. astra/torch/al/acquisitions/centroid.py: implementation of the Centroid acquisition function.
  3. astra/torch/al/strategies/diversity.py: modified so the strategy passes embeddings to the acquisition functions.
  4. tests/torch/acquisitions/test_furthest.py: tests for the Furthest acquisition function.
  5. tests/torch/acquisitions/test_centroid.py: tests for the Centroid acquisition function.

All test cases pass, including the pre-existing ones (commit a6e59ff).

Explanation:

  1. furthest.py: For the Furthest acquisition function, we use the furthest_first method of the class distil.active_learning_strategies.core_set.CoreSet, passing a dummy strategy object as an argument along with labeled_embeddings, unlabeled_embeddings, and n. It returns a list of the indices of n unlabeled data points, picked greedily so that each one is furthest from the labeled points and from the points already picked.
  2. centroid.py: For the Centroid acquisition function, we pass labeled_embeddings, unlabeled_embeddings, and n as input.

In centroid.py, the lines below initialize min_dist as a tensor of length len(pool), filled with infinity, when there are no labeled points (n_train is 0).

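    if labeled_embeddings.shape[0] == 0:
        # Create min_dist on the same device as the embeddings, otherwise
        # torch.minimum fails when the embeddings live on a GPU
        min_dist = torch.full(
            (unlabeled_embeddings.shape[0],), float("inf"), device=unlabeled_embeddings.device
        )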

Otherwise, we compute the centroid of the labeled (train) embeddings and the distance from every pool point to that centroid.

    else:
        centroid_embedding = torch.mean(labeled_embeddings, dim=0).unsqueeze(0)
        dist_ctr = torch.cdist(unlabeled_embeddings, centroid_embedding, p=2)
        min_dist = torch.min(dist_ctr, dim=1)[0]

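We then greedily select n pool points: at each step we take the point with the largest min_dist, then update min_dist with every pool point's distance to the newly selected point, so that later picks stay far from earlier ones.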

    idxs = []
    for i in range(n):
        idx = torch.argmax(min_dist)
        idxs.append(idx.item())
        dist_new_ctr = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[[idx], :])
        min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])
    return idxs
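Since the two acquisitions share the same greedy loop and differ only in how min_dist is initialized (distance to the nearest labeled point for Furthest, distance to the labeled centroid for Centroid), both can be written without DISTIL. The sketch below assumes furthest_first follows the standard greedy k-center rule; the function names (furthest_acquisition, centroid_acquisition, and the helpers) are hypothetical and not the names used in furthest.py or centroid.py.

```python
import torch


def _greedy_furthest(unlabeled_embeddings, min_dist, n):
    # Greedy loop shared by both acquisitions: take the pool point with the largest
    # min_dist, then shrink min_dist using each point's distance to the new pick.
    idxs = []
    for _ in range(n):
        idx = int(torch.argmax(min_dist))
        idxs.append(idx)
        new_dist = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[idx : idx + 1], p=2)[:, 0]
        min_dist = torch.minimum(min_dist, new_dist)
    return idxs


def _all_inf(unlabeled_embeddings):
    # No labeled points yet: treat every pool point as infinitely far away.
    # Created on the embeddings' device so torch.minimum never mixes CPU and GPU tensors.
    return torch.full(
        (unlabeled_embeddings.shape[0],), float("inf"), device=unlabeled_embeddings.device
    )


def furthest_acquisition(unlabeled_embeddings, labeled_embeddings, n):
    # Hypothetical name. Start from each pool point's distance to its nearest labeled point.
    if labeled_embeddings.shape[0] == 0:
        min_dist = _all_inf(unlabeled_embeddings)
    else:
        min_dist = torch.cdist(unlabeled_embeddings, labeled_embeddings, p=2).min(dim=1).values
    return _greedy_furthest(unlabeled_embeddings, min_dist, n)


def centroid_acquisition(unlabeled_embeddings, labeled_embeddings, n):
    # Hypothetical name. Start from each pool point's distance to the labeled centroid.
    if labeled_embeddings.shape[0] == 0:
        min_dist = _all_inf(unlabeled_embeddings)
    else:
        centroid = labeled_embeddings.mean(dim=0, keepdim=True)
        min_dist = torch.cdist(unlabeled_embeddings, centroid, p=2)[:, 0]
    return _greedy_furthest(unlabeled_embeddings, min_dist, n)


pool = torch.tensor([[5.0, 0.0], [0.0, 2.0], [5.0, 4.0]])
labeled = torch.tensor([[0.0, 0.0], [10.0, 0.0]])
print(furthest_acquisition(pool, labeled, 2))  # [2, 0]
print(centroid_acquisition(pool, labeled, 2))  # [1, 2]
```

In this toy pool, the point (5, 0) sits exactly on the labeled centroid, so Centroid does not select it, while Furthest picks it second because it is 5 units away from both labeled points.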
  3. diversity.py: The acquisition function in DISTIL (referenced above) takes (unlabeled_embeddings, labeled_embeddings, n) as parameters, so I used the same signature and modified diversity.py accordingly, instead of the (features, pool_indices, context_indices) signature suggested in diversity.py of sustainability-lab/ASTRA.
  4. and 5. test_furthest.py and test_centroid.py: Used CIFAR10 for testing. Here we want to pass the model's feature extractor instead of its forward pass, so I implemented the feature extractor as below:

    # Define the model
    net = CNN(32, 3, 3, [4, 8], [2, 3], 10).to(device)

    def extract_features(net):
        def feature_extractor(input_tensor):
            # Initialize features with the input tensor
            features = input_tensor

            # Apply each layer, activation, and max-pooling
            for layer in net.feature_extractor:
                features = layer(features)
                features = net.activation(features)
                features = net.max_pool(features)
            features = net.flatten(features)
            return features

        return feature_extractor

    # Create a feature extractor callable from the network
    feature_extractor = extract_features(net)

This feature_extractor returns the features, i.e., the embeddings of the input.
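To check the shapes end to end without CIFAR10 or ASTRA's CNN, here is a minimal runnable sketch. ToyCNN is hypothetical: it only mirrors the attribute names that extract_features relies on (feature_extractor, activation, max_pool, flatten). Its layer sizes, and therefore the 512-dimensional output, are assumptions of this toy and not properties of CNN(32, 3, 3, [4, 8], [2, 3], 10).

```python
import torch
import torch.nn as nn


class ToyCNN(nn.Module):
    # Hypothetical stand-in for ASTRA's CNN; only the attribute names are mirrored.
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.ModuleList(
            [nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.Conv2d(4, 8, kernel_size=3, padding=1)]
        )
        self.activation = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)  # halves height and width
        self.flatten = nn.Flatten()


def extract_features(net):
    def feature_extractor(input_tensor):
        features = input_tensor
        # Conv layer -> activation -> max-pooling, as in the test code above
        for layer in net.feature_extractor:
            features = net.max_pool(net.activation(layer(features)))
        return net.flatten(features)

    return feature_extractor


images = torch.rand(5, 32, 32, 3)    # (n_samples, height, width, channels)
images = images.permute(0, 3, 1, 2)  # (n_samples, channels, height, width), as Conv2d expects
feature_extractor = extract_features(ToyCNN())
with torch.no_grad():  # embeddings only, no autograd graph needed
    features = feature_extractor(images)
print(features.shape)  # torch.Size([5, 512])
```

Each conv keeps 32x32 thanks to padding=1 and each pooling halves it, so two blocks give 8 channels x 8 x 8 = 512 features per image.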
```python
# Example only: this snippet is not part of the committed code
# input shape: (n_samples, height, width, channels)
input = input.permute(0, 3, 1, 2)  # input shape: (n_samples, channels, height, width)
features = feature_extractor(input)  # shape: (n_samples, feature_dim)
```

We then pass this feature_extractor to strategy.query(), which returns best_indices according to the acquisition function (Furthest or Centroid) the strategy uses.

```python
# Query the strategy
best_indices = strategy.query(
    feature_extractor, pool_indices, train_indices, n_query_samples=n_query_samples
)
```
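With the change to diversity.py, the strategy (not the acquisition function) is responsible for turning indices into embeddings. The hypothetical sketch below illustrates that query step; query_diversity, its explicit inputs argument, and the simplified farthest_from_centroid acquisition (top-k distance to the labeled centroid, without the greedy update) are illustrative assumptions, not ASTRA's API. The point it shows is that the acquisition returns positions within the pool, which must be mapped back through pool_indices.

```python
import torch


def query_diversity(feature_extractor, inputs, pool_indices, train_indices, acquisition, n_query_samples):
    # Hypothetical query step: embed pool and train points, run the acquisition,
    # and map the chosen pool positions back to dataset indices.
    with torch.no_grad():
        unlabeled_embeddings = feature_extractor(inputs[pool_indices])
        labeled_embeddings = feature_extractor(inputs[train_indices])
    # The acquisition returns positions within the pool, not dataset indices.
    positions = acquisition(unlabeled_embeddings, labeled_embeddings, n_query_samples)
    return pool_indices[positions]


def farthest_from_centroid(unlabeled_embeddings, labeled_embeddings, n):
    # Simplified, non-greedy stand-in: the n pool points farthest from the labeled centroid.
    centroid = labeled_embeddings.mean(dim=0, keepdim=True)
    dist = torch.cdist(unlabeled_embeddings, centroid)[:, 0]
    return torch.topk(dist, n).indices.tolist()


inputs = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [9.0, 0.0], [0.0, 8.0]])
train_indices = torch.tensor([0, 1])
pool_indices = torch.tensor([2, 3, 4])
selected = query_diversity(
    lambda x: x, inputs, pool_indices, train_indices, farthest_from_centroid, 2
)
print(selected)  # tensor([3, 4])
```

Here the acquisition picks pool positions 1 and 2, which correspond to dataset indices 3 and 4.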
patel-zeel commented 10 months ago

@jaiswalsuraj487 Now that our plan has broadened, let's not use the distil library. Use your own implementation. Can you show visually whether your acquisition functions pick the correct points?

jaiswalsuraj487 commented 10 months ago

@patel-zeel I have made the required changes to match the current version of sustainability-lab:main and added sandbox/diveristy_acquisition_demo.ipynb, which visualizes the data points selected by each acquisition function on dummy data.

jaiswalsuraj487 commented 10 months ago

@patel-zeel Added an AL notebook for the diversity acquisitions: notebooks/al/diversity_acq_AL.ipynb