allenhaozhu / protoLP

Imbalanced setting #4

Open 21Johnson21 opened 3 months ago

21Johnson21 commented 3 months ago

Hello author, how can I change the code in test_standard_GSSL_lapshot_unbalance.py to reproduce the performance in an imbalanced setting? Thank you for your work! Have a nice day☺

21Johnson21 commented 3 months ago

I used the following code to start the imbalanced experiment, but the results I got were so different from those in your paper that I don't know what went wrong:

    FSLTask.loadDataSet("cub")
    FSLTask.setUnbalancedRandomStates(cfg)
    ndatas, labels = FSLTask.GenerateUnbalancedRunSet(start=0, end=n_runs, cfg=cfg)

allenhaozhu commented 3 months ago

Sorry, my PC has been returned to ANU, so I cannot find the original code. But I can give you a clue. In Realistic_Transductive_Few_Shot/src/datasets/sampler.py, use

    query_samples = get_dirichlet_query_dist(alpha, 1, self.n_cls, self.n_cls * self.q_shot)[0]

to generate query_samples (alpha = 2 follows the NeurIPS setting). Please note that our code does not need explicit labels, because samples of different classes are placed along different axes of the tensor. In the imbalanced setting we therefore need to generate the labels at the same time, and those labels are also used in the validation. Sorry for the inconvenience, but it is a simple modification. You need to do two things: generate the query samples and their labels with this function, then pass the labels explicitly to the other functions.
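The idea in this reply (draw imbalanced per-class query counts from a Dirichlet distribution, then build the matching labels) can be sketched as follows. This is only an illustration under stated assumptions: `sample_query_counts` and `labels_from_counts` are hypothetical helpers, not functions from either repository, and the largest-remainder rounding used here is not necessarily how `get_dirichlet_query_dist` converts proportions to counts.

```python
import numpy as np

def sample_query_counts(alpha, n_ways, n_query_total, rng):
    """Hypothetical helper: per-class query counts that sum to n_query_total."""
    # Class proportions from a symmetric Dirichlet distribution.
    proportions = rng.dirichlet(np.full(n_ways, alpha))
    raw = proportions * n_query_total
    counts = np.floor(raw).astype(int)
    # Give the leftover samples to the classes with the largest remainders
    # (an assumed rounding rule, chosen so the counts sum exactly).
    leftover = int(n_query_total - counts.sum())
    order = np.argsort(raw - counts)[::-1]
    counts[order[:leftover]] += 1
    return counts

def labels_from_counts(counts):
    """Hypothetical helper: class i repeated counts[i] times."""
    return np.repeat(np.arange(len(counts)), counts)

rng = np.random.default_rng(0)
counts = sample_query_counts(alpha=2.0, n_ways=5, n_query_total=75, rng=rng)
query_labels = labels_from_counts(counts)
print(counts, counts.sum(), query_labels.shape)
```

Because the counts differ from class to class, the labels can no longer be inferred from the position of a sample along a class axis, which is why they have to be generated and passed around explicitly.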

allenhaozhu commented 3 months ago

I think the unbalanced script should be the same as the normal one; I provided the wrong copy a year ago.

21Johnson21 commented 3 months ago

Hello author, do you mean that in the normal version, 'ndatas = FSLTask.GenerateRunSet(cfg=cfg)', the function GenerateRunSet should use get_dirichlet_query_dist() to generate the query samples? The ndatas it generates currently has a shape such as (10000, 5, 16, 512), where the axes represent the different categories, so no labels need to be generated. At the same time, get_dirichlet_query_dist() would be used to generate the query labels for the validation, just like 'labels = torch.arange(n_ways).view(1,1,n_ways).expand(n_runs,n_shot+n_queries,5).clone().view(n_runs, n_samples)'. Is that right?
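For reference, the balanced label expression quoted in this comment can be run on its own to see the layout it produces. The sizes below are small values chosen for illustration, and `n_ways` is used in place of the hard-coded 5.

```python
import torch

n_runs, n_ways, n_shot, n_queries = 2, 5, 1, 15
n_samples = n_ways * (n_shot + n_queries)

# Class ids cycle 0..n_ways-1 along the flattened sample axis.
labels = (
    torch.arange(n_ways)
    .view(1, 1, n_ways)
    .expand(n_runs, n_shot + n_queries, n_ways)
    .clone()
    .view(n_runs, n_samples)
)
print(labels.shape)
print(labels[0, :10])
print(torch.bincount(labels[0]))
```

Every class appears exactly `n_shot + n_queries` times in each task, so this layout only describes the balanced case; with Dirichlet-sampled query counts the labels have to follow the actual sample order instead.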

allenhaozhu commented 3 months ago

Yes. Then you need to modify every function that relies on the label information encoded by the tensor axes.

(Sent by email in reply to 21Johnson21's comment of Thursday, 18 July 2024, 00:56.)

21Johnson21 commented 3 months ago

Hello author, the experiment used miniImageNet (WRN-28-10), 5-way 1-shot. get_dirichlet_query_dist() is used to generate the query samples and query labels, as shown in the figure. Regrettably, the highest accuracy over the epochs was 64.58% and the final accuracy was 62.75%, which is far from the results in your paper. I did look at the functions that use labels, but perhaps I have missed something. Do you have any suggestions?

[Screenshot: Snipaste_2024-07-18_11-20-01]

allenhaozhu commented 3 months ago

Could you share your modification so that I can check the issue?

21Johnson21 commented 3 months ago
ndatas, labels = FSLTask.GenerateRunSet(cfg=cfg)  

def GenerateRunSet(start=None, end=None, cfg=None):
    global dataset, _maxRuns
    if start is None:
        start = 0
    if end is None:
        end = _maxRuns
    if cfg is None:
        cfg = {"shot": 1, "ways": 5, "queries": 15}

    setRandomStates(cfg)
    print("generating task from {} to {}".format(start, end))

    dataset = torch.zeros((end - start, cfg['ways'] * (cfg['shot'] + cfg['queries']), data.shape[2]))
    query_labels = torch.zeros((end-start, cfg['queries'] * cfg['ways']))
    support_labels = torch.zeros((end - start, cfg['shot'] * cfg['ways']))
    #get_dirichlet_query_dist(2, n_tasks, n_ways, q_shots)
    for iRun in range(end-start):
        dataset[iRun], support_labels[iRun], query_labels[iRun] = GenerateRun(start+iRun, cfg)
    labels = torch.cat((support_labels,query_labels ), dim=1).long()
    return dataset, labels

def GenerateRun(iRun, cfg, regenRState=False, generate=True):
    global _randStates, data, _min_examples
    if not regenRState:
        np.random.set_state(_randStates[iRun])

    classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
    shuffle_indices = np.arange(_min_examples)
    dataset = None
    if generate:
        dataset = torch.zeros((cfg['ways'], cfg['shot'], data.shape[2]))
        support_labels = torch.zeros(cfg['ways'], cfg['shot'], dtype=torch.int64)

    alpha = 2 * np.ones(cfg['ways'])
    query_samples = get_dirichlet_query_dist(alpha, 1, cfg['ways'], cfg['queries'] * cfg['ways'])[0]
    querySet = []
    labelSet = []
    for i in range(cfg['ways']):
        shuffle_indices = np.random.permutation(shuffle_indices)
        if generate:
            dataset[i] = data[classes[i], shuffle_indices, :][:cfg['shot']]
            support_labels[i] = i

            if query_samples[i] > data[classes[i], shuffle_indices, :].shape[0]:
                dist = query_samples[i] - data[classes[i], shuffle_indices, :].shape[0]
                query = data[classes[i], shuffle_indices, :][:query_samples[i]]
                query_extra = data[classes[i], shuffle_indices[:dist], :][:query_samples[i]]
                query = torch.cat((query, query_extra), dim=0)
            else:
                query = data[classes[i], shuffle_indices, :][:query_samples[i]]
            querySet.append(query)
            label_que = i * torch.ones(query_samples[i], dtype=torch.int64)
            labelSet.append(label_que)
    querys = torch.cat(querySet, dim=0)
    querys_labels = torch.cat(labelSet, dim=0)
    dataset = torch.cat((dataset.reshape(-1, data.shape[2]), querys), dim=0)
    support_labels = support_labels.reshape(-1)

    return dataset, support_labels, querys_labels
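
One detail in the snippet above may be worth checking: the support features are taken with `[:cfg['shot']]` and the query features with `[:query_samples[i]]` from the same shuffled index list, so both slices start at position 0 and can contain the same samples. A minimal sketch of a disjoint split for one class is shown below; `sample_class_split` is a hypothetical helper, and whether a disjoint split matches the original protocol is an assumption.

```python
import numpy as np
import torch

def sample_class_split(class_feats, n_shot, n_query, rng):
    """Hypothetical helper: disjoint support and query features for one class.

    class_feats: tensor of shape (n_examples, feat_dim).
    """
    n_examples = class_feats.shape[0]
    if n_shot + n_query > n_examples:
        raise ValueError("not enough examples for a disjoint split")
    perm = torch.from_numpy(rng.permutation(n_examples))
    support = class_feats[perm[:n_shot]]                  # first n_shot samples
    query = class_feats[perm[n_shot:n_shot + n_query]]    # the next n_query samples
    return support, query

# Toy class with 20 examples whose single feature is the example index.
class_feats = torch.arange(20.0).unsqueeze(1)
support, query = sample_class_split(class_feats, 1, 7, np.random.default_rng(0))
print(support.shape, query.shape)
```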
allenhaozhu commented 3 months ago

This must produce not only dataset and support_labels but also query_labels. You also need to modify every place that implicitly uses the labels. For example, when we compute the centres, we use the mean function on [0:5]. That implicitly uses the labels, so instead you need to multiply the features by the labels, sum the result, and divide by the number of support samples. Did you implement this function yourself? It implicitly uses the labels, and you need to pass the generated labels in:

    def getAccuracy(self, probas):
        olabels = probas.argmax(dim=2)
        matches = labels.eq(olabels).float()
        acc_test = matches[:,n_lsamples:].mean(1)

        m = acc_test.mean().item()
        pm = acc_test.std().item() *1.96 / math.sqrt(n_runs)
        return m, pm
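
The label-based centre computation described in this reply (multiply the features by the labels, sum, and divide by the support count) might look roughly like the sketch below. It assumes the labels are first turned into a one-hot membership matrix; `class_prototypes` is a hypothetical name and this is not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def class_prototypes(features, labels, n_ways):
    """Hypothetical helper: mean feature per class from explicit labels.

    features: (n_runs, n_support, feat_dim)
    labels:   (n_runs, n_support) integer class ids in [0, n_ways)
    """
    one_hot = F.one_hot(labels, n_ways).to(features.dtype)  # (runs, support, ways)
    sums = one_hot.transpose(1, 2) @ features                # (runs, ways, dim)
    counts = one_hot.sum(dim=1).unsqueeze(-1)                # (runs, ways, 1)
    # clamp avoids a division by zero for a class without support samples
    return sums / counts.clamp(min=1)

# One task, three support samples: two of class 0 and one of class 1.
features = torch.tensor([[[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]])
labels = torch.tensor([[0, 0, 1]])
protos = class_prototypes(features, labels, n_ways=2)
print(protos)
```

Unlike a mean over a fixed slice, this does not depend on the samples being stored in any particular class order.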
21Johnson21 commented 3 months ago

Dear author, yes, I referred to some other code for unbalanced settings. I have adjusted the centering function as you instructed. However, I am unclear about how to integrate the generated labels into the getAccuracy function. Is there anything else that needs to change for the imbalanced setting? Thank you for your patience; I appreciate it.

Best wishes

def centerDatas(datas):
    weights = labels.unsqueeze(-1)
    weighted_data = datas * weights
    weighted_mean_center = weighted_data.sum(dim=1) / weights.sum(dim=1)
    centered_data = datas - weighted_mean_center.unsqueeze(1)

    datas = centered_data / torch.norm(centered_data, dim=2, keepdim=True)

    return datas
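
One thing that may be worth checking in the function above: `labels.unsqueeze(-1)` holds raw class indices (0, 1, 2, ...), so multiplying by it scales each sample by its class id, and every class-0 sample gets weight zero. If the intent is a per-class weighting, a one-hot membership mask is the usual alternative. The toy tensors below are made up purely to show the difference.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[0, 1, 2, 0]])    # one task, four samples
datas = torch.ones(1, 4, 3)              # dummy features, all ones

# Raw class ids used as weights: rows of class 0 are zeroed out.
int_weights = labels.unsqueeze(-1)
weighted = datas * int_weights
print(weighted[0, :, 0])

# One-hot mask: entry [run, sample, c] is 1 if the sample belongs to class c.
one_hot = F.one_hot(labels, num_classes=3).float()
per_class_counts = one_hot.sum(dim=1)
print(per_class_counts)
```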
allenhaozhu commented 3 months ago

At least, please reimplement the function getAccuracy; it implicitly uses the labels.

21Johnson21 commented 3 months ago

Dear author, I modified the function that implicitly uses the labels. The final result improved slightly to around 63% (miniImageNet, 5-way 1-shot), but it is still far from the result in your paper.

    def getAccuracy(self, probas, labels, n_lsamples, n_runs):
        olabels = probas.argmax(dim=2)
        matches = labels.eq(olabels).float()
        acc_test = matches[:, n_lsamples:].mean(1)

        m = acc_test.mean().item()
        pm = acc_test.std().item() * 1.96 / math.sqrt(n_runs)

        return m, pm
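
A standalone version of this function can be checked on a toy batch whose accuracy is known by hand, which may help to rule out a mismatch between the order of the labels and the order of the samples. The sketch below assumes that `probas` has shape (n_runs, n_samples, n_ways) and that the first `n_lsamples` positions of each task are the support samples; `query_accuracy` is a hypothetical name.

```python
import math
import torch
import torch.nn.functional as F

def query_accuracy(probas, labels, n_lsamples):
    """Mean query accuracy and 95% confidence half-width over tasks."""
    preds = probas.argmax(dim=2)
    matches = preds.eq(labels).float()
    acc_per_task = matches[:, n_lsamples:].mean(dim=1)   # skip support positions
    n_runs = acc_per_task.shape[0]
    mean = acc_per_task.mean().item()
    half_width = acc_per_task.std().item() * 1.96 / math.sqrt(n_runs)
    return mean, half_width

# Two tasks, each with 1 support sample followed by 2 query samples.
labels = torch.tensor([[0, 1, 0], [1, 0, 0]])
preds = torch.tensor([[0, 1, 0], [1, 1, 0]])   # task 1 gets one query wrong
probas = F.one_hot(preds, num_classes=2).float()
mean, half_width = query_accuracy(probas, labels, n_lsamples=1)
print(mean, half_width)
```

Here the first task scores 2 of 2 queries and the second 1 of 2, so the mean accuracy is 0.75.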
allenhaozhu commented 3 months ago

n_lsamples

It sounds weird because you can get reasonable results (not 20%) in random label queries.

21Johnson21 commented 3 months ago

Dear author, Indeed, the performance of randomly configured settings appears to be normal, which has left me confused. I am uncertain whether the unbalanced configuration necessitates any alterations beyond those found in the normal version of the code, aside from the implicit utilization of labels we previously discussed. best wishes.