Closed by jiaqima 1 year ago.
Description
Refactor `dataloader.py` and `LoadModel.py` to accommodate the new datasets and models prepared by Caiwei.
Test
Running the following code yields an accuracy of 1.0. Changing to `data_name='lending-club'` yields an accuracy of 0.815.

```python
import torch

from openxai import LoadModel
from openxai.dataloader import return_loaders

data_name = 'rcdv'  # change to 'lending-club' to test the other dataset

# Download the dataset and build the train/test loaders
loader_train, loader_test = return_loaders(data_name=data_name, download=True)
# Take the first test batch; use the built-in next(), since the .next()
# method is not available on DataLoader iterators in recent PyTorch versions
inputs, labels = next(iter(loader_test))

# Load the pretrained ANN for this dataset and predict the argmax class
model = LoadModel(data_name=data_name, ml_model='ann', pretrained=True)
prob = model(inputs.to(dtype=torch.float))
pred = torch.argmax(prob, dim=1)

# Fraction of correct predictions on this batch
accuracy = (pred == labels).float().mean().item()
print(accuracy)
```
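The test snippet scores only the first batch returned by `loader_test`, which may or may not cover the whole test split depending on the loader's batch size. Below is a minimal sketch of averaging accuracy over every batch instead, assuming (as in the snippet) that the model maps a float feature batch to per-class scores of shape `(batch_size, num_classes)`. `evaluate_accuracy` is a hypothetical helper, not part of OpenXAI, and the toy linear model and synthetic data exist only to exercise it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model: torch.nn.Module, loader: DataLoader) -> float:
    """Fraction of correct argmax predictions over every batch in `loader`.

    Hypothetical helper: assumes `model` returns per-class scores of shape
    (batch_size, num_classes) for a float input batch.
    """
    correct = 0
    total = 0
    model.eval()  # put dropout / batch-norm layers in inference mode
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in loader:
            scores = model(inputs.to(dtype=torch.float))
            preds = torch.argmax(scores, dim=1)
            correct += (preds == labels.view(-1)).sum().item()
            total += labels.numel()
    return correct / total


# Toy check on synthetic data: a fixed linear "model" that predicts
# class 1 exactly when the first feature is positive.
torch.manual_seed(0)
x = torch.randn(100, 3)
y = (x[:, 0] > 0).long()
model = torch.nn.Linear(3, 2, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
loader = DataLoader(TensorDataset(x, y), batch_size=32)
print(evaluate_accuracy(model, loader))  # 1.0
```

With the real objects from the test snippet, the call would be `evaluate_accuracy(model, loader_test)`. Counting correct predictions and samples separately, rather than averaging per-batch accuracies, keeps the result exact when the last batch is smaller than the others.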