tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

CTC loss #1759

Open marsiancba opened 5 years ago

marsiancba commented 5 years ago

Please add support for CTC loss.

My use-case is recognizing handwritten numbers (1-5 digits).

fahimsun commented 5 years ago

Please add CTC support in Tensorflow JS

CTC is crucial in all cases where the input length is not equal to the output length, for example speech recognition using RNNs, or handwritten digits in sequence.
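The length mismatch is resolved by CTC's collapse rule: a frame-level alignment is mapped to a shorter label sequence by merging repeats and dropping blanks. A minimal greedy-decoding sketch in plain TypeScript (an illustration only, not tfjs code; `ctcGreedyDecode` is a hypothetical helper name):

```typescript
// Greedy CTC decoding: collapse repeated symbols, then drop blanks.
// A 7-frame alignment like [1, 1, 0, 2, 2, 0, 2] (0 = blank) decodes
// to the 3-label sequence [1, 2, 2], so 7 input frames can yield
// anywhere from 0 to 7 output labels.
function ctcGreedyDecode(frames: number[], blank = 0): number[] {
    const out: number[] = [];
    let prev = -1;
    for (const f of frames) {
        // Emit a label only when it differs from the previous frame
        // and is not the blank symbol.
        if (f !== prev && f !== blank) out.push(f);
        prev = f;
    }
    return out;
}
```

Note how the blank between the two 2s is what allows a genuinely repeated label to survive collapsing.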

WenheLI commented 5 years ago

I'm interested in this one; can I have a try at it?

rthadur commented 5 years ago

@WenheLI sure, please submit a new Pull Request.

marsiancba commented 5 years ago

@WenheLI I can provide you with a rough (but working) implementation of CTC loss in tfjs, if you want something to start from...

WenheLI commented 5 years ago

@marsiancba Sure, that will be really helpful!

fahimsun commented 5 years ago

@marsiancba sure, I am interested. It would be very helpful.

marsiancba commented 5 years ago

export function ctcLoss(
    /** 
     * batch_size / max_label_seq_length 
     */
    labels: tf.Tensor2D,
    /**
     * batch_size / frames / num_labels
     */
    logits: tf.Tensor3D,
    options: {} = {}
): tf.Tensor1D {
    return tf.tidy(() => {
        const vec = ctcLoss_vec(labels, logits, options);

        return vec.mean().neg() as tf.Tensor1D;
    })
}

export function ctcLoss_vec(
    labels: tf.Tensor2D,
    logits: tf.Tensor3D,
    options: {} = {}
): tf.Tensor2D {
    // Debug printer; flip DEBUG to true to trace intermediate tensors.
    // (The original had an unconditional `return;` that left the body unreachable.)
    const DEBUG = false;
    function p(name: string, t?: tf.Tensor) {
        if (!DEBUG) return;
        if (t)
            printTensor(name, t);
        else
            console.log(name);
    }

    return tf.tidy(() => {

        const SUPER_SMALL = -1e9; // stands in for -Infinity, i.e. log of zero probability

        const logits_normalized = logits.sub(logits.logSumExp(2, true));

        p('labels', labels);
        p('logits', logits);
        p('logits_normalized', logits_normalized);

        const [batch_size, num_time_steps] = logits.shape;
        if (labels.shape[0] !== batch_size) ASC.dieError('1694042736'); // shape sanity check (ASC is the author's helper)
        const max_label_seq_length = labels.shape[1];
        const y0_size = max_label_seq_length * 2 + 1;
        const y_size = y0_size + 1;

        p('labels', labels);
        p('logits', logits);
        p('logits_normalized', logits_normalized);
        p('max_label_seq_length=' + max_label_seq_length);
        p('y0_size=' + y0_size);

        const labels_buff = labels.bufferSync();

        const y_loc_buff = tf.buffer([batch_size, y0_size, 2], 'int32');
        const res_loc_buff = tf.buffer([batch_size, 2], 'int32');

        for (let b = 0; b < batch_size; b++) {
            for (let y = 0; y < y0_size; y++)
                y_loc_buff.set(b, b, y, 0);
            let len = max_label_seq_length;
            for (let t = 0; t < max_label_seq_length; t++) {
                const l = labels_buff.get(b, t);
                if (l == 0) {
                    len = t;
                    break;
                }
                const y = t * 2 + 1;
                y_loc_buff.set(l, b, y, 1);
            }

            res_loc_buff.set(b, b, 0);
            res_loc_buff.set(2 * len, b, 1);
        }

        const incoming_loc_buff = tf.buffer([batch_size, y_size, 3, 2], 'int32');
        for (let b = 0; b < batch_size; b++) {
            for (let y = 0; y <= y0_size; y++) {
                for (let i = 0; i < 3; i++)
                    incoming_loc_buff.set(b, b, y, i, 0);
            }
            for (let y = 0; y < y0_size; y++) {
                incoming_loc_buff.set(y, b, y, 0, 1);
                incoming_loc_buff.set(y > 0 ? y - 1 : y0_size, b, y, 1, 1);
                // CTC skip transition (y-2 -> y) is allowed only when the label
                // differs from the one two states back.
                let moze_double = false;
                if (y > 2) {
                    if (y_loc_buff.get(b, y, 1) != y_loc_buff.get(b, y - 2, 1))
                        moze_double = true;
                }
                incoming_loc_buff.set(moze_double ? y - 2 : y0_size, b, y, 2, 1);
            }
            incoming_loc_buff.set(y0_size, b, y0_size, 0, 1);
            incoming_loc_buff.set(y0_size, b, y0_size, 1, 1);
            incoming_loc_buff.set(y0_size, b, y0_size, 2, 1);
        }

        const y_loc = y_loc_buff.toTensor();
        const res_loc = res_loc_buff.toTensor();
        const incoming_loc = incoming_loc_buff.toTensor();
        p('y_loc', y_loc);
        p('res_loc', res_loc);
        p('incoming_loc', incoming_loc);

        const y0 = gatherND(
            logits_normalized.transpose([0, 2, 1]),
            y_loc,
        )
            .transpose([0, 2, 1]);

        const y = y0.pad([[0, 0], [0, 0], [0, 1]], SUPER_SMALL);

        p('y', y);

        let log_alpha =
            tf.scalar(0).reshape([1, 1]).tile([batch_size, 1])
                .concat(
                    tf.scalar(SUPER_SMALL).reshape([1, 1]).tile([batch_size, y0_size]),
                    1
                );

        function shift(t: tf.Tensor) {
            return t.pad([[0, 0], [1, 0]], SUPER_SMALL).slice([0, 0], t.shape);
        }
        function logSumExp(a: tf.Tensor, b: tf.Tensor) {
            return a.expandDims(2).concat(b.expandDims(2), 2).logSumExp(2);
        }

        const t2y = y.unstack(1);
        for (let t = 0; t < num_time_steps; t++) {
            p("Time: " + t);

            const ty = t2y[t];
            p('log_alpha', log_alpha);
            p('ty', ty);

            const incoming = gatherND(log_alpha, incoming_loc);
            p('incoming', incoming);

            const incoming_plus_ty = incoming.add(ty.expandDims(2));
            p('incoming_plus_ty', incoming_plus_ty);

            const new_log_alpha2 = incoming_plus_ty.logSumExp(2);
            p('new_log_alpha2', new_log_alpha2);

            log_alpha = new_log_alpha2;
        }

        const log_alpha_final = logSumExp(log_alpha, shift(log_alpha));
        p('log_alpha_final', log_alpha_final);

        const vec = gatherND(log_alpha_final, res_loc);
        //printTensor('vec', vec);

        return vec as tf.Tensor2D;
    })
}

function gatherND(x: tf.Tensor, indices: tf.Tensor): tf.Tensor {
    const grad = (dy: tf.Tensor, saved: tf.Tensor[]) => {
        return { x: () => tf.scatterND(saved[0], dy, x.shape) }
    }
    return ENGINE.runKernel(
        (backend, save) => {
            save([indices]);
            return backend.gatherND(x, indices);
        },
        { x },
        grad
    ) as tf.Tensor;
}
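The alpha recursion above accumulates probabilities in log space, combining incoming paths with logSumExp for numerical stability. A scalar sketch of that log-domain addition, using only standard Math (an illustration of the trick, not part of the implementation above):

```typescript
// Numerically stable log-sum-exp of two log-domain values:
// logAdd(log a, log b) = log(a + b), without overflowing exp().
function logAdd(a: number, b: number): number {
    // log(0) cases: adding zero probability leaves the other term.
    if (a === -Infinity) return b;
    if (b === -Infinity) return a;
    // Factor out the max so the remaining exponent is <= 0.
    const m = Math.max(a, b);
    return m + Math.log(Math.exp(a - m) + Math.exp(b - m));
}
```

This is why the code pads with a very negative constant (SUPER_SMALL) rather than zero: in log space, "impossible path" must behave like minus infinity so it drops out of the sum.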
fahimsun commented 5 years ago

@WenheLI are you working on this?

WenheLI commented 5 years ago

@fahimsun Yep, but I am on a break right now, so it may take some time to implement. If you want to take it over, that is fine as well.

jasonmayes commented 2 years ago

Update: This is blocking new folks' work around handwriting recognition and speech recognition from being published. Is there an update on this?

Please see: https://discuss.tensorflow.org/t/ctc-loss-implementation-in-tfjs/6645/5

harangp commented 2 years ago

Hi, I've prepared a working CTC loss / gradient calculator for TFJS. It works on batches, handles variable-length labels for learning, and is built entirely from scratch based on the original papers. It is pluggable into a TFJS model to calculate losses and gradients during model.fit(). Open sourced here: https://github.com/harangp/tfjsctcloss. Drop me a line if you are interested.

gaikwadrahul8 commented 1 year ago

Hi, @marsiancba

Apologies for the delayed response. It seems we haven't implemented this feature request yet, so are you still looking for this feature?

@harangp, would you like to contribute this feature? If so, please refer to this link. Thank you!

marsiancba commented 1 year ago

Hi, @gaikwadrahul8. The part of our project where we used tfjs is stalled at the moment, but we plan to return to it sometime in the future. We used our own implementation (https://github.com/tensorflow/tfjs/issues/1759#issuecomment-534565149) and it seemed to work OK, but a native implementation will probably be better/faster.

harangp commented 1 year ago

@harangp, would you like to contribute this feature? If so, please refer to this link. Thank you!

Yes, in time. I still have to work on making the backward and collection parts native, and I'm not confident enough about the returned gradients. I also want to utilize masks to make things more efficient. And lastly, there's a problem with it not being drop-in compatible with the Python version, so I'm not sure it would fit into the TFJS concept.