zuoxingdong / lagom

lagom: A PyTorch infrastructure for rapid prototyping of reinforcement learning algorithms.
MIT License
373 stars 30 forks source link

utils #150

Closed zuoxingdong closed 5 years ago

zuoxingdong commented 5 years ago
def tensorify(x, device):
    """Convert ``x`` to a ``torch.Tensor`` located on ``device``.

    Args:
        x: a ``torch.Tensor``, a NumPy array, or anything ``np.asarray``
            accepts (Python scalars, nested lists, ...).
        device: target device, e.g. ``'cpu'`` or a ``torch.device``.

    Returns:
        torch.Tensor: an input tensor keeps its dtype and is only moved to
        ``device``. Every other input is converted to ``float32``. A
        ``float32`` array without negative strides is wrapped without a
        copy, so on CPU the result shares memory with the input array.
    """
    if torch.is_tensor(x):
        # Tensor.to returns self unchanged when it is already on ``device``.
        return x.to(device)
    # np.asarray passes ndarrays through untouched, so a single path covers
    # both arrays and generic array-likes. Casting on the NumPy side also
    # admits dtypes torch.from_numpy may reject (e.g. uint16 on older torch);
    # copy=False avoids a copy when the array is already float32.
    arr = np.asarray(x).astype(np.float32, copy=False)
    # torch.from_numpy raises on negative strides (views such as ``a[::-1]``).
    # Copy only in that case so the common path stays zero-copy.
    if any(stride < 0 for stride in arr.strides):
        arr = arr.copy()
    return torch.from_numpy(arr).to(device)

def numpify(x, dtype):
    """Convert ``x`` to a NumPy array of the given ``dtype``.

    Args:
        x: a ``torch.Tensor`` (possibly requiring grad or on an accelerator),
            a NumPy array, or anything ``np.asarray`` accepts.
        dtype: target NumPy dtype, passed to ``astype`` / ``np.asarray``.

    Returns:
        np.ndarray: for tensors and arrays, a new array produced by
        ``astype(dtype)``. Other inputs go through ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x.astype(dtype)
    if torch.is_tensor(x):
        # Detach from the autograd graph and bring the data to host memory
        # before handing it to NumPy.
        host_array = x.detach().cpu().numpy()
        return host_array.astype(dtype)
    return np.asarray(x, dtype=dtype)