toshas / torch-householder

Efficient Householder Transformation in PyTorch

Slow performance compared to `torch.linalg.householder_product` in forward pass #7

Open · bytesnake opened this issue 2 years ago

bytesnake commented 2 years ago

Problem

I'm using an orthonormality-constrained matrix in a learnable filterbank setting. Now I want to optimize training, so I ran some profiling with the PyTorch profiler, but I'm getting strange results. I just want to double-check here whether I'm doing something wrong.

Code

I'm constructing the matrix during forward pass like this:

def __init__(self, ..):
    [..]

    # householder decomposition
    decomp, tau = torch.geqrf(filters)

    # assume that DCT is orthogonal
    filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])

    # register everything as parameter and set gradient flags
    self.filters = torch.nn.Parameter(filters, requires_grad=True)
    self.register_parameter('filter_q', self.filters)

def filters(self):
    valid_coeffs = self.filters.tril(diagonal=-1)
    tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
    return torch.linalg.householder_product(valid_coeffs, tau)
    # return torch_householder_orgqr(valid_coeffs, tau)
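
For context on the tau formula above: each reflector stores only its strictly lower-triangular entries and carries an implicit 1 on the diagonal, so tau = 2 / (1 + ||w||^2), where w is the stored part. A minimal standalone check of that convention (illustrative names, not taken from the code above):

import torch

d = 8
v = torch.zeros(d)
v[0] = 1.0                              # implicit unit diagonal entry
v[1:] = torch.randn(d - 1)              # strictly lower, stored part of the reflector
tau = 2.0 / (1.0 + v[1:].norm() ** 2)   # same formula as in filters()
H = torch.eye(d) - tau * torch.outer(v, v)
print(torch.allclose(H @ H.T, torch.eye(d), atol=1e-5))  # True: the reflector is orthogonal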

Profiles

All profiles were created with the PyTorch profiler, using a warmup of one step and two measured runs:
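
Roughly, the profiler is set up like this (a sketch; the model and batch below are stand-ins, not the actual filterbank code):

import torch
from torch.profiler import profile, schedule, ProfilerActivity

# stand-ins for the real filterbank model and input batch
model = torch.nn.Linear(512, 512).cuda()
batch = torch.randn(64, 512, device="cuda")

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=0, warmup=1, active=2),
)
with prof:
    for _ in range(3):          # 1 warmup step + 2 measured steps
        out = model(batch)      # forward pass
        out.sum().backward()    # backward pass
        prof.step()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))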

Profile torch householder_product (matrix 512x512 f32)

The forward pass and backward pass are marked in light green:

[screenshot: profiler trace]

Profile torch-householder (matrix 512x512 f32)

[screenshot: profiler trace]

Questions

I'm not an expert in torch and do not follow its development closely. There is an issue, https://github.com/pytorch/pytorch/issues/50104, about adding CUDA support to orgqr; could this cause the difference in runtime?

I'm also happy to share the traces with you, please just ping me :)

toshas commented 2 years ago

Thanks for this analysis. A couple of things to double-check:

bytesnake commented 2 years ago

> This will give you the estimate of runtime speed and memory consumption ratios at the time I checked last

The evaluation uses "thin" matrices, but in my case all matrices are square, for perfect reconstruction. This means that for m=32 we would need to go to d=8192, which is eight times beyond the range of the performance comparison chart.

> Are you performing warmup in both cases?

Yes, I'm doing a warmup of one iteration to launch all CUDA kernels etc., but I will try to increase it.

EDIT: tried that, but there is no difference for warmup=4

bytesnake commented 2 years ago

Probably my expectations are too high and I need to reparametrize the model, but I still want to double-check that I'm using the Householder parametrization correctly.

toshas commented 2 years ago

Can you please run https://github.com/toshas/torch-householder/blob/master/tests/benchmark_one.py with the values of width (r), height (d), and batch size (b) of interest?

bytesnake commented 2 years ago

$ python3 benchmark_one.py --repeats 5 --method hh_ours --b 1 --d 512 --r 512
70.72763475589454 1.2755393754559918e-06
$ python3 benchmark_one.py --repeats 5 --method hh_pt --b 1 --d 512 --r 512
318.1826988235116 1.3709068298339844e-06

To be really comparable, I have to modify the test: it creates B independent decompositions, but in my case a single one is shared across the whole batch.
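
Roughly what I mean, as a sketch (illustrative shapes; the single Q is built once and then reused for every batch element):

import torch

d, b = 512, 20
param = torch.randn(d, d).tril(diagonal=-1)        # one shared parametrization
tau = 2. / (1 + torch.norm(param, dim=0) ** 2)
Q = torch.linalg.householder_product(param, tau)   # built once per forward pass

x = torch.randn(b, d)                              # a batch of inputs
y = x @ Q.T                                        # Q is shared across the whole batch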

Now with a single orthonormal matrix and a batch of B inputs (I believe that the decomposition works fine ;)):

$ python3 benchmark_one.py --repeats 20 --method hh_ours --b 20 --d 512 --r 512
19.521982199512422
$ python3 benchmark_one.py --repeats 20 --method hh_pt --b 20 --d 512 --r 512
2.61343854945153

Still not really comparable...

patch.txt

toshas commented 2 years ago

I applied the patch, and I'm not sure what the modified benchmark is supposed to verify. As originally implemented, it timed pure calls to the functions. After the patch, an extra inp tensor is created and multiplied directly by the param tensor, while the transformation result out is never used.
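
One way to make such a comparison meaningful, roughly (a sketch with a hypothetical helper, not the benchmark code): build the orthogonal matrix with the method under test, apply it to a batch, backpropagate through the result so both the forward and backward of the transformation are actually in the timed path, and synchronize before reading the clock.

import time
import torch

def time_shared_matrix(orgqr_fn, d=512, b=20, repeats=20, device="cuda"):
    # orgqr_fn: the parametrization under test, e.g. torch.linalg.householder_product
    param = torch.randn(d, d, device=device).tril(diagonal=-1).requires_grad_()
    x = torch.randn(b, d, device=device)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeats):
        tau = 2. / (1 + torch.norm(param, dim=0) ** 2)
        Q = orgqr_fn(param, tau)    # the transformation under test
        y = x @ Q.T                 # actually use the result ...
        y.sum().backward()          # ... and exercise the backward pass
        param.grad = None
    torch.cuda.synchronize()        # wait for all kernels before stopping the clock
    return (time.time() - start) / repeats

print(time_shared_matrix(torch.linalg.householder_product))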