hyperdimensional-computing / torchhd

Torchhd is a Python library for Hyperdimensional Computing and Vector Symbolic Architectures
MIT License

Using add_online for Centroid model #164

Closed: OMGAmici closed this issue 5 months ago

OMGAmici commented 5 months ago

Is the correct way to use the add_online() method as simple as using it in place of the add() method?

So far, I have not seen any change in performance after introducing it, which makes me suspect that either I am using it incorrectly or there is a bug.

I've tried using it as in the snippet below, and also inserting it after the first accuracy measurement, but neither had any effect on the outcome.

```python
with torch.no_grad():
    for samples, labels in tqdm(train_ld, desc="Training"):
        samples = samples.to(device)
        labels = labels.to(device)
        samples_hv = encode(samples).to(device)
        model.add_online(samples_hv, labels)

accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes).to(device)
```

mikeheddes commented 5 months ago

Hi, thank you for opening this issue. Yes, it should be as simple as that. A few things to note: 1) the improvement in accuracy will depend on the dataset you are using, and 2) online learning is typically run for multiple iterations (epochs) to further boost accuracy.

I modified examples/voicehd.py to use add_online and got the following test accuracies:

- add: 85.119%
- add_online, 1 epoch: 86.273%
- add_online, 2 epochs: 83.515%
- add_online, 3 epochs: 87.877%

I modified examples/voicehd.py as follows:

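```python
# Train with add_online for 3 epochs, measuring test accuracy after each one.
for _ in range(3):

    with torch.no_grad():
        for samples, labels in tqdm(train_ld, desc="Training"):
            samples = samples.to(device)
            labels = labels.to(device)

            samples_hv = encode(samples)
            # Online update of the class centroids (used in place of model.add).
            model.add_online(samples_hv, labels)

    # A fresh metric each epoch, so every print reports only that epoch's accuracy.
    accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)

    with torch.no_grad():
        for samples, labels in tqdm(test_ld, desc="Testing"):
            samples = samples.to(device)

            samples_hv = encode(samples)
            outputs = model(samples_hv, dot=False)
            # The metric lives on the CPU, so compare CPU outputs with the labels.
            accuracy.update(outputs.cpu(), labels)

    print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")
```
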
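To make the difference between add and add_online concrete, here is a small pure-Python sketch that trains one centroid model by plain bundling and another with several epochs of online updates, evaluating after each epoch as in the loop above. It assumes an OnlineHD-style, error-driven rule for the online update (a misclassified sample pulls its true class centroid toward it and pushes the predicted centroid away, each scaled by 1 - similarity). torchhd's actual add_online may differ in details, so the functions `add`, `add_online`, `noisy_copy`, the toy data and the update rule are all hypothetical illustrations, not the library's code.

```python
import math
import random


def cosine(a, b):
    """Cosine similarity; an all-zero vector (e.g. an empty centroid) scores 0.0."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def predict(centroids, hv):
    """Return (predicted class, cosine similarity to each centroid)."""
    sims = [cosine(c, hv) for c in centroids]
    return max(range(len(sims)), key=sims.__getitem__), sims


def add(centroids, hv, label):
    """Plain bundling: every sample is summed into its own class centroid."""
    for i, v in enumerate(hv):
        centroids[label][i] += v


def add_online(centroids, hv, label, lr=1.0):
    """Hypothetical OnlineHD-style update (a sketch, not torchhd's code).

    A correctly classified sample leaves the model unchanged; a misclassified
    one moves the true and the predicted centroids in opposite directions.
    """
    pred, sims = predict(centroids, hv)
    if pred == label:
        return
    pull = lr * (1.0 - sims[label])  # >= 0: move the true class toward hv
    push = lr * (sims[pred] - 1.0)   # <= 0: move the wrong class away from hv
    for i, v in enumerate(hv):
        centroids[label][i] += pull * v
        centroids[pred][i] += push * v


def noisy_copy(proto, flip, rng):
    """Flip each bipolar component of a class prototype with probability `flip`."""
    return [-v if rng.random() < flip else v for v in proto]


def accuracy(centroids, data):
    return sum(predict(centroids, hv)[0] == y for hv, y in data) / len(data)


rng = random.Random(0)
D, CLASSES = 1000, 3
protos = [[rng.choice((-1, 1)) for _ in range(D)] for _ in range(CLASSES)]
train = [(noisy_copy(protos[y], 0.1, rng), y) for y in range(CLASSES) for _ in range(30)]
test = [(noisy_copy(protos[y], 0.1, rng), y) for y in range(CLASSES) for _ in range(20)]
rng.shuffle(train)

# One pass of plain bundling.
bundled = [[0.0] * D for _ in range(CLASSES)]
for hv, y in train:
    add(bundled, hv, y)

# Several passes (epochs) of online updates, evaluated after each epoch.
online = [[0.0] * D for _ in range(CLASSES)]
online_history = []
for _ in range(3):
    for hv, y in train:
        add_online(online, hv, y)
    online_history.append(accuracy(online, test))

print(f"add: {accuracy(bundled, test):.3f}")
for epoch, acc in enumerate(online_history, start=1):
    print(f"add_online, {epoch} epoch(s): {acc:.3f}")
```

In this sketch a correctly classified sample leaves the model untouched, so each additional epoch only adjusts the centroids where they still make mistakes. On real data the per-epoch accuracy need not rise monotonically, as the 2-epoch voicehd result above shows.
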
OMGAmici commented 5 months ago

Thank you! That makes sense.