scikit-hep / pyhf

pure-Python HistFactory implementation with tensors and autodiff
https://pyhf.readthedocs.io/
Apache License 2.0

pytorch TPU XLA support #1248

lukasheinrich opened this issue 3 years ago (status: open)

lukasheinrich commented 3 years ago

Initial experiments show that, modulo some smallish fixes, PyTorch/XLA could work:

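import pyhf
import torch
import torch_xla
import torch_xla.core.xla_model as xm  # xm.xla_device() returns the XLA (e.g. TPU) device

# Use pyhf's PyTorch tensor backend so the model's tensors are torch tensors
pyhf.set_backend("pytorch")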

spec = {
    'channels': [
        {
            'name': 'singlechannel',
            'samples': [
                {
                    'name': 'signal',
                    'data': [5],
                    'modifiers': [
                        {'name': 'mu', 'type': 'normfactor', 'data': None}
                    ],
                },
                {
                    'name': 'background',
                    'data': [50],
                    'modifiers': []
                },
            ],
        }
    ]
}
m = pyhf.Model(spec)
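For orientation, the spec above describes a single bin in which the normfactor `mu` scales only the signal sample, so the expected count is 5·mu + 50. Below is a minimal hand-written sketch in plain PyTorch (not pyhf's internal implementation) of that expected count and its Poisson negative log-likelihood, differentiated with autograd. It assumes `torch` is installed and uses torch_xla's `xm.xla_device()` only if `torch_xla` is importable, falling back to the CPU otherwise. The helper `pick_device` and the observed count of 55 are hypothetical choices for illustration.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: XLA device if torch_xla is installed, else CPU."""
    try:
        import torch_xla.core.xla_model as xm  # optional dependency
    except ImportError:
        return torch.device("cpu")
    return xm.xla_device()


device = pick_device()

# Nominal yields from the spec: one bin, signal = 5, background = 50.
signal = torch.tensor([5.0], device=device)
background = torch.tensor([50.0], device=device)
# Hypothetical observed count (assumption): equal to the mu = 1 expectation.
observed = torch.tensor([55.0], device=device)

# The normfactor 'mu' is the only parameter; track gradients through it.
mu = torch.tensor(1.0, device=device, requires_grad=True)

# Expected count: the normfactor scales the signal sample only.
expected = mu * signal + background

# Poisson negative log-likelihood: lambda - n*log(lambda) + log(n!)
nll = (expected - observed * torch.log(expected) + torch.lgamma(observed + 1.0)).sum()
nll.backward()  # fills mu.grad with d(nll)/d(mu)

print(expected.item(), mu.grad.item())
```

With the assumed observation of 55 events, the gradient with respect to `mu` at `mu = 1` is zero, because 55 is exactly the expected count there.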


That said, it's unclear whether having multiple high-level tensor libraries (JAX, PyTorch) that can target XLA is beneficial (though the same is true for GPUs).

This is the XLA graph:

[Figure: XLA graph]

matthewfeickert commented 3 years ago

> Initial experiments show that, modulo some smallish fixes, PyTorch/XLA could work

This is great to see @lukasheinrich. Can you comment on how this fits in with Issue #1244?

matthewfeickert commented 3 years ago

Ah, after looking at the docs and the pytorch/xla repo, I now see that it isn't as trivial to install as torch and is probably best used via a Docker image.

lukasheinrich commented 3 years ago

I think it might be fairly related: you just pass a different device to the tensors. The CLs computation might be trickier, since XLA has a static graph, so I'm not sure how it deals with conditionals.
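One common way to keep a data-dependent conditional inside a static graph is to express it as an elementwise select instead of a Python `if`. The sketch below illustrates this with the textbook one-sided convention, where a test statistic is set to zero whenever the best-fit μ̂ exceeds the tested μ. It is not pyhf's code. The function name `one_sided` and the input values are hypothetical. As we understand torch_xla's lazy-tensor model, turning a tensor into a Python `bool` forces the pending graph to execute and copies the result to the host, while `torch.where` stays a single traced op.

```python
import torch


def one_sided(test_stat: torch.Tensor, muhat: torch.Tensor, mu: float) -> torch.Tensor:
    """Hypothetical helper: zero the statistic wherever muhat exceeds mu.

    A Python ``if muhat > mu:`` would need the tensor value on the host and
    branch outside the traced graph; torch.where keeps the selection as one
    elementwise op, so the graph structure does not depend on the data.
    """
    return torch.where(muhat > mu, torch.zeros_like(test_stat), test_stat)


# CPU here; with torch_xla installed one could pass xm.xla_device() instead.
device = torch.device("cpu")
t = torch.tensor([3.0, 4.0], device=device)      # illustrative statistic values
muhat = torch.tensor([0.5, 2.0], device=device)  # illustrative best-fit values
q = one_sided(t, muhat, mu=1.0)
print(q)
```

Both branches are computed and one is selected per element. This wastes a little work, but it avoids a host round-trip on every evaluation.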

lukasheinrich commented 3 years ago

here's the diff https://github.com/scikit-hep/pyhf/pull/1249/files

matthewfeickert commented 3 years ago

> I think it might be fairly related: you just pass a different device to the tensors. The CLs computation might be trickier, since XLA has a static graph, so I'm not sure how it deals with conditionals.

Nice. Yeah looking at the diff it seems pretty reasonable so far.