Open jaiswalsuraj487 opened 10 months ago
@jaiswalsuraj487 Now that our plan has broadened, let's not use the distil library; use your own implementation instead. Can you show visually whether your acquisition is picking the correct points?
@patel-zeel I have made the required changes to match the current version of sustainability-lab:main and added sandbox/diveristy_acquisition_demo.ipynb, which visualizes the data points selected by each acquisition function on dummy data.
@patel-zeel Added an AL notebook for the diversity acquisitions: notebooks/al/diversity_acq_AL.ipynb
Implemented Furthest Acquisition and Centroid Acquisition in commit a6e59ff.
Files Added:
astra/torch/al/acquisitions/furthest.py: contains the implementation of the Furthest acquisition.
astra/torch/al/acquisitions/centroid.py: contains the implementation of the Centroid acquisition.
astra/torch/al/strategies/diversity.py: modified this file as needed.
tests/torch/acquisitions/test_furthest.py: contains the test for the furthest acquisition function.
tests/torch/acquisitions/test_centroid.py: contains the test for the centroid acquisition function.
All test cases pass, including the pre-existing ones (commit: a6e59ff).
Explanation:
furthest.py: The furthest acquisition function uses the furthest_first method of the class distil.active_learning_strategies.core_set.CoreSet (link). We pass a dummy strategy object as an argument, along with labeled_embeddings, unlabeled_embeddings and n. It returns a list of indices of the n pool points that are furthest from the labeled points and from the points already selected.
centroid.py: The centroid acquisition function takes labeled_embeddings, unlabeled_embeddings, and n as input. When n_train is 0, min_dist is initialized as a tensor of size [len(n_pool)] with all values set to infinity. Otherwise, we compute the centroid of the train data and then the pairwise distance between the centroid and all pool data. We then take the indices of the n pool points with the largest distance.
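The two selection rules described above can be sketched in plain PyTorch as follows. This is a minimal illustration, not the code in furthest.py or centroid.py: the function names `furthest_acquisition` and `centroid_acquisition` are hypothetical, Euclidean distance is an assumption, and the greedy loop is a generic furthest-first (k-center) procedure rather than a copy of the distil method.

```python
import torch


def furthest_acquisition(labeled_embeddings, unlabeled_embeddings, n):
    """Greedy furthest-first selection (hypothetical helper, Euclidean distance assumed)."""
    n_pool = unlabeled_embeddings.shape[0]
    if labeled_embeddings.shape[0] == 0:
        # No labeled data yet: treat every pool point as infinitely far away.
        min_dist = torch.full((n_pool,), float("inf"), dtype=unlabeled_embeddings.dtype)
    else:
        # Distance from each pool point to its nearest labeled point.
        min_dist = torch.cdist(unlabeled_embeddings, labeled_embeddings).min(dim=1).values
    selected = []
    for _ in range(min(n, n_pool)):
        idx = int(torch.argmax(min_dist))
        selected.append(idx)
        # The chosen point now acts as a labeled point: refresh nearest distances.
        new_dist = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[idx:idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)
    return selected


def centroid_acquisition(labeled_embeddings, unlabeled_embeddings, n):
    """Pick the n pool points furthest from the labeled centroid (hypothetical helper)."""
    n_pool = unlabeled_embeddings.shape[0]
    if labeled_embeddings.shape[0] == 0:
        # No labeled data: all distances are infinite, so any n points qualify.
        dist = torch.full((n_pool,), float("inf"), dtype=unlabeled_embeddings.dtype)
    else:
        centroid = labeled_embeddings.mean(dim=0, keepdim=True)
        dist = torch.cdist(unlabeled_embeddings, centroid).squeeze(1)
    return torch.topk(dist, k=min(n, n_pool)).indices.tolist()


# Dummy data: one labeled point at the origin and three pool points.
labeled = torch.tensor([[0.0, 0.0]])
pool = torch.tensor([[10.0, 0.0], [9.0, 0.0], [0.0, 5.0]])
furthest_pick = furthest_acquisition(labeled, pool, 2)
centroid_pick = centroid_acquisition(labeled, pool, 2)
```

On this dummy data the two rules disagree: the centroid rule takes the two points with the largest distance from the origin, which sit next to each other, while furthest-first skips the second of them because it lies close to the point already chosen.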
diversity.py: Since the acquisition function implemented in link takes (unlabeled_embeddings, labeled_embeddings, n) as parameters, I did the same and modified diversity.py accordingly, instead of using the (features, pool_indices, context_indices) signature suggested in diversity.py of sustainability-lab/ASTRA.
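For reference, the two signatures can be bridged with a thin adapter. The sketch below is only an assumption about how that could look: `split_embeddings` and `to_dataset_indices` are hypothetical helpers, and it assumes that context_indices index the labeled rows of features while pool_indices index the unlabeled rows.

```python
import torch


def split_embeddings(features, pool_indices, context_indices):
    """Hypothetical adapter from index-based arguments to embedding-based ones."""
    labeled_embeddings = features[context_indices]   # assumed: context = labeled set
    unlabeled_embeddings = features[pool_indices]    # assumed: pool = unlabeled set
    return labeled_embeddings, unlabeled_embeddings


def to_dataset_indices(pool_indices, local_indices):
    """Map positions within the pool embeddings back to dataset-level indices."""
    return pool_indices[torch.as_tensor(local_indices, dtype=torch.long)]


# Dummy features for six samples with two dimensions each.
features = torch.arange(12.0).reshape(6, 2)
pool_indices = torch.tensor([1, 3, 5])
context_indices = torch.tensor([0, 2])

labeled, unlabeled = split_embeddings(features, pool_indices, context_indices)
# Suppose an acquisition returned positions 2 and 0 within the pool embeddings.
chosen = to_dataset_indices(pool_indices, [2, 0])
```

An embedding-based acquisition returns positions relative to unlabeled_embeddings, so a mapping step like `to_dataset_indices` is needed before those positions can be used as dataset indices.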
test_furthest.py and test_centroid.py: Used CIFAR10 to test. Here we want to pass the model's feature extractor instead of its forward pass, so I implemented a feature extractor as below:
    def extract_features(net):
        def feature_extractor(input_tensor):
            # Initialize features with the input tensor
            ...
    # Create a feature extractor callable from the network
    feature_extractor = extract_features(net)
We then pass this feature_extractor to strategy.query(), which returns best_indices based on the furthest or centroid acquisition provided.
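The body of feature_extractor is not reproduced in this comment. One possible sketch is given below, under the assumption that the network is a sequential stack whose last child module is the classification head; the toy `net` and the CIFAR10-shaped random batch are illustrative only and are not the model used in the tests.

```python
import torch
from torch import nn


def extract_features(net):
    # Assumption: applying the children in order, minus the last one, yields
    # the penultimate-layer features. This holds for nn.Sequential-style
    # models but not for networks with custom forward() logic.
    layers = list(net.children())[:-1]

    def feature_extractor(input_tensor):
        # Initialize features with the input tensor
        features = input_tensor
        with torch.no_grad():
            for layer in layers:
                features = layer(features)
        return features

    return feature_extractor


# Toy network standing in for the real model (illustrative only).
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 16),
    nn.ReLU(),
    nn.Linear(16, 10),
)

# Create a feature extractor callable from the network
feature_extractor = extract_features(net)
batch = torch.randn(4, 3, 32, 32)  # CIFAR10-shaped dummy batch
embeddings = feature_extractor(batch)
```

The returned embeddings have one row per input image and the width of the layer just before the head (16 here), which is the shape a distance-based acquisition expects.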