calico / baskerville

Machine learning methods for DNA sequence analysis.
Apache License 2.0

Clarification of poisson_multinomial (possibly returns wrong shape?) #38

Closed: tdsone closed this issue 4 months ago

tdsone commented 4 months ago

Hey there,

Thanks a lot for the repo; it's super informative and has been a great help to me in understanding and writing my own code!

While writing a PyTorch implementation of poisson_multinomial, I noticed the following: poisson_multinomial returns a tensor of shape [2, 2, 10] for inputs of shape [2, 200, 10].

My interpretation of the input shape is: batch size 2, sequence length 200, 10 tracks. As far as I understand, the Poisson-multinomial loss has one value per track, which would give a shape of [2, 1, 10] (or some other shape with 20 losses). Instead, it returns losses of shape [2, 2, 10]. I'm sure I'm missing something here, but why are there twice as many losses as I expect?

Here's a script to reproduce the behavior (mostly copied from this repo except for everything after the loss function):

import tensorflow as tf
import torch
import numpy as np

def poisson(yt, yp, epsilon: float = 1e-7):
    """Poisson loss, without mean reduction."""
    return yp - yt * tf.math.log(yp + epsilon)

def poisson_multinomial(
    y_true,
    y_pred,
    total_weight: float = 1,
    weight_range: float = 1,
    weight_exp: int = 4,
    epsilon: float = 1e-7,
    rescale: bool = False,
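):
    """Poisson decomposition with multinomial specificity term.

    Args:
        y_true: Observed counts, shape (batch, length, tracks).
        y_pred: Predicted counts, same shape as y_true.
        total_weight (float): Weight of the Poisson total term.
        weight_range (float): Approximate ratio of center to edge position weight; 1 gives uniform weights.
        weight_exp (int): Exponent setting how sharply the position weights fall off toward the edges.
        epsilon (float): Added small value to avoid log(0).
        rescale (bool): Rescale loss after re-weighting.
    """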
    seq_len = y_true.shape[1]
    pos_start = -(seq_len / 2 - 0.5)
    pos_end = seq_len / 2 + 0.5
    sigma = -pos_start / (np.log(weight_range)) ** (1 / weight_exp)

    positions = tf.range(pos_start, pos_end, dtype=tf.float32)
    position_weights = tf.exp(-((positions / sigma) ** weight_exp))
    position_weights /= tf.reduce_max(position_weights)
    position_weights = tf.expand_dims(position_weights, axis=0)
    position_weights = tf.expand_dims(position_weights, axis=-1)

    y_true = tf.math.multiply(y_true, position_weights)
    y_pred = tf.math.multiply(y_pred, position_weights)

    # sum across lengths
    s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True)
    s_pred = tf.math.reduce_sum(y_pred, axis=-2, keepdims=True)

    # total count Poisson loss, normalized by the summed position weights
    poisson_term = poisson(s_true, s_pred)  # B x 1 x T (keepdims=True above)
    poisson_term /= tf.reduce_sum(position_weights)

    # add epsilon to protect against tiny values
    y_true += epsilon
    y_pred += epsilon

    # normalize to sum to one
    p_pred = y_pred / s_pred

    # multinomial loss
    pl_pred = tf.math.log(p_pred)  # B x L x T
    multinomial_dot = -tf.math.multiply(y_true, pl_pred)  # B x L x T
    multinomial_term = tf.math.reduce_sum(multinomial_dot, axis=-2)  # B x T
    multinomial_term /= tf.reduce_sum(position_weights)

    # normalize to scale of 1:1 term ratio
    loss_raw = multinomial_term + total_weight * poisson_term
    if rescale:
        loss_rescale = loss_raw * 2 / (1 + total_weight)
    else:
        loss_rescale = loss_raw

    return loss_rescale

torch.manual_seed(42)
np.random.seed(42)
bs, tracks, seq_len = 2, 10, 200
y_true = torch.randint(low=0, high=1000, size=(bs, seq_len, tracks)).float()
y_pred = torch.randint(low=0, high=1000, size=(bs, seq_len, tracks)).float()

loss_tf = poisson_multinomial(
    y_true=tf.convert_to_tensor(y_true.numpy()),
    y_pred=tf.convert_to_tensor(y_pred.numpy()),
    weight_range=1.1,
)

print(f"{loss_tf.shape=}")

Thanks a lot for your help!

Best Timon
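For reference, the script above computes the following. The notation is chosen here, not taken from the repo: y = y_true, y_hat = y_pred, batch index b, position l, track t, r = weight_range, p = weight_exp, lambda = total_weight, epsilon = epsilon.

```latex
\begin{aligned}
x_\ell &= \ell - \tfrac{L-1}{2}, \qquad \ell = 0,\dots,L-1 \\
\sigma &= \frac{(L-1)/2}{(\ln r)^{1/p}}, \qquad
\tilde w_\ell = \exp\!\big(-(x_\ell/\sigma)^p\big), \qquad
w_\ell = \frac{\tilde w_\ell}{\max_k \tilde w_k} \\
S_{bt} &= \sum_\ell w_\ell\, y_{b\ell t}, \qquad
\hat S_{bt} = \sum_\ell w_\ell\, \hat y_{b\ell t} \\
P_{bt} &= \frac{\hat S_{bt} - S_{bt}\,\ln\big(\hat S_{bt} + \varepsilon\big)}{\sum_\ell w_\ell} \\
M_{bt} &= -\frac{1}{\sum_\ell w_\ell} \sum_\ell \big(w_\ell\, y_{b\ell t} + \varepsilon\big)\,
          \ln \frac{w_\ell\, \hat y_{b\ell t} + \varepsilon}{\hat S_{bt}} \\
\mathcal{L}_{bt} &= M_{bt} + \lambda\, P_{bt}
\end{aligned}
```

With even p (the default is 4), the outermost unnormalized weights are exp(-ln r) = 1/r. So weight_range is roughly the ratio of center to edge weight, and r = 1 makes sigma infinite and the weights uniform. If rescale is set, L_bt is further multiplied by 2 / (1 + lambda). Every term carries only the indices (b, t), i.e. one value per example and track: the 20 values the question expects.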

davek44 commented 4 months ago

Sorry, this is code under active development. I pushed a modified version that should clarify the shapes.
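The extra axis in the posted script comes from NumPy-style broadcasting, which TensorFlow follows. s_true and s_pred are summed with keepdims=True, so poisson_term is [B, 1, T], while multinomial_term is reduced without keepdims and is [B, T]. Adding them aligns trailing dimensions and treats [B, T] as [1, B, T], so the result is [B, B, T]. The leading 2 in [2, 2, 10] is the batch size, not a doubling. The sketch below reproduces only the reductions with NumPy. It is a simplification for illustration, and posted_loss_shape is a hypothetical helper, not repo code:

```python
import numpy as np


def posted_loss_shape(batch: int, seq_len: int, tracks: int) -> tuple:
    """Hypothetical helper: mimic only the reductions of the posted script."""
    y = np.ones((batch, seq_len, tracks))
    s = y.sum(axis=-2, keepdims=True)                     # B x 1 x T (keepdims=True)
    poisson_term = s - s * np.log(s + 1e-7)               # stays B x 1 x T
    multinomial_term = -(y * np.log(y / s)).sum(axis=-2)  # B x T (no keepdims)
    return (multinomial_term + poisson_term).shape        # broadcasts to B x B x T


print(posted_loss_shape(2, 200, 10))  # (2, 2, 10)
print(posted_loss_shape(3, 200, 10))  # (3, 3, 10)
```

Element [i, j, t] of that result is the multinomial term of example j plus the Poisson term of example i, so only the diagonal i == j is a per-example loss. Reducing the totals without keepdims, or squeezing axis -2 out of poisson_term, gives the expected [B, T].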

tdsone commented 4 months ago

Thanks for the quick reply! Makes sense now.
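For anyone porting the loss, as in the PyTorch implementation mentioned in the question, here is a minimal NumPy sketch of the posted script's computation that keeps both terms at [B, T]. It follows the script step by step except that the totals are reduced without keepdims. The names position_weights and poisson_multinomial_np are hypothetical, and this is not necessarily the modified version pushed to the repo:

```python
import numpy as np


def position_weights(seq_len: int, weight_range: float = 1.0, weight_exp: int = 4) -> np.ndarray:
    """Hypothetical helper: center-peaked weights of shape (seq_len,), max 1."""
    if weight_range == 1:
        # log(1) == 0 makes sigma infinite in the script, i.e. uniform weights
        return np.ones(seq_len)
    # positions -(L-1)/2 ... (L-1)/2, matching tf.range(pos_start, pos_end)
    positions = np.arange(seq_len, dtype=np.float64) - (seq_len - 1) / 2
    sigma = ((seq_len - 1) / 2) / np.log(weight_range) ** (1 / weight_exp)
    w = np.exp(-((positions / sigma) ** weight_exp))
    return w / w.max()


def poisson_multinomial_np(y_true, y_pred, total_weight=1.0, weight_range=1.0,
                           weight_exp=4, epsilon=1e-7, rescale=False):
    """Hypothetical sketch: Poisson-multinomial loss for (B, L, T) inputs, returns (B, T)."""
    w = position_weights(y_true.shape[1], weight_range, weight_exp)
    w_sum = w.sum()
    w = w[None, :, None]  # 1 x L x 1, broadcasts over batch and tracks
    y_true = y_true * w
    y_pred = y_pred * w

    # total counts per example and track; no keepdims, so B x T
    s_true = y_true.sum(axis=1)
    s_pred = y_pred.sum(axis=1)
    poisson_term = (s_pred - s_true * np.log(s_pred + epsilon)) / w_sum  # B x T

    # epsilon is added after the totals, as in the script
    y_true = y_true + epsilon
    y_pred = y_pred + epsilon
    p_pred = y_pred / s_pred[:, None, :]  # B x L x T
    multinomial_term = -(y_true * np.log(p_pred)).sum(axis=1) / w_sum  # B x T

    loss = multinomial_term + total_weight * poisson_term  # B x T
    if rescale:
        loss = loss * 2 / (1 + total_weight)
    return loss


rng = np.random.default_rng(42)
y_true = rng.integers(0, 1000, size=(2, 200, 10)).astype(np.float64)
y_pred = rng.integers(0, 1000, size=(2, 200, 10)).astype(np.float64)
print(poisson_multinomial_np(y_true, y_pred, weight_range=1.1).shape)  # (2, 10)
```

Because the batch axis is never combined with itself, row out[b] equals the loss computed for example b alone. The same structure carries over to PyTorch by replacing the NumPy sums over axis=1 with torch.sum(..., dim=1).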