LSSTDESC / RESSPECT

The RESSPECT project is the result of an inter-collaboration agreement between the Cosmostatistics Initiative (COIN) and the LSST Dark Energy Science Collaboration (DESC). Its goal is to develop a recommendation system for telescope resource allocation that optimizes photometric supernova cosmology analysis.
MIT License

Make CertaintySampling query strategy #87

Open drewoldag opened 22 hours ago

drewoldag commented 22 hours ago

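This should be nearly identical to the UncSampling class, except that the dist value should measure how close each object's probability is to the "anomaly" value (1 or 0, whichever represents the anomaly) rather than to the decision boundary. See the notes below: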

import numpy as np
import pandas as pd

def uncertainty_sampling(class_prob: np.ndarray, test_ids: np.ndarray,
                         queryable_ids: np.ndarray, batch=1,
                         screen=False, query_thre=1.0) -> list:

    if class_prob.shape[0] != test_ids.shape[0]:
        raise ValueError('Number of probabilities is different ' +
                         'from number of objects in the test sample!')

    # calculate distance to the decision boundary - only binary classification
    #! Change this to be whichever is closest to 1 (or 0, whichever is the anomaly)
    dist = abs(class_prob[:, 1] - 0.5)

    # get indexes in increasing order
    order = dist.argsort()

    # only allow objects in the query sample to be chosen
    flag = list(pd.Series(data=test_ids[order]).isin(queryable_ids))

    # check if there are queryable objects within threshold
    indx = int(len(flag) * query_thre)

    if sum(flag[:indx]) > 0:

        # arrange queryable elements in increasing order
        flag = np.array(flag)
        final_order = order[flag]

        if screen:
            print('\n Inside UncSampling: ')
            print('       query_ids: ', test_ids[final_order][:batch], '\n')
            print('   number of test_ids: ', test_ids.shape[0])
            print('   number of queryable_ids: ', len(queryable_ids), '\n')
            print('   *** Displacement caused by constraints on query****')
            print('   0 -> ', list(order).index(final_order[0]))
            print('   ', class_prob[order[0]], '-- > ', class_prob[final_order[0]], '\n')

        # return the indexes of the most uncertain objects which are queryable
        return list(final_order)[:batch]

    else:
        return []
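
One possible shape for the requested change, as a minimal sketch rather than the final implementation. It assumes binary classification and that the anomaly probability sits in one column of class_prob. The function name `certainty_sampling` and the `anomaly_class` parameter are hypothetical and not part of the existing RESSPECT API; the screen output is omitted for brevity, and `np.isin` is used in place of the pandas call.

```python
import numpy as np


def certainty_sampling(class_prob, test_ids, queryable_ids,
                       batch=1, query_thre=1.0, anomaly_class=1):
    """Return indexes of the queryable objects most certain to be anomalies.

    Hypothetical sketch: `anomaly_class` (the column of `class_prob` holding
    the anomaly probability) is an assumption, not an existing parameter.
    """
    class_prob = np.asarray(class_prob)
    test_ids = np.asarray(test_ids)

    if class_prob.shape[0] != test_ids.shape[0]:
        raise ValueError('Number of probabilities is different '
                         'from number of objects in the test sample!')

    # distance to full certainty in the anomaly class (0 means "surely anomaly")
    dist = 1.0 - class_prob[:, anomaly_class]

    # indexes in increasing order of distance, i.e. most certain first
    order = dist.argsort(kind='stable')

    # only allow objects in the query sample to be chosen
    flag = np.isin(test_ids[order], queryable_ids)

    # check if there are queryable objects within threshold
    indx = int(len(flag) * query_thre)
    if flag[:indx].sum() > 0:
        return order[flag][:batch].tolist()
    return []


# Toy data: four objects, anomaly probability in column 1
prob = np.array([[0.9, 0.1], [0.2, 0.8], [0.05, 0.95], [0.5, 0.5]])
ids = np.array([10, 11, 12, 13])

# Object 12 is the most certain anomaly but is not queryable, so it is skipped
picked = certainty_sampling(prob, ids, np.array([10, 11, 13]), batch=2)
print(picked)  # [1, 3]
```

The only substantive difference from uncertainty_sampling is the dist line: sorting by `1 - p(anomaly)` ranks objects by closeness to the anomaly value instead of closeness to 0.5. Setting `anomaly_class=0` covers the case where the anomaly is encoded as the first column.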