ottokart / punctuator2

A bidirectional recurrent neural network model with attention mechanism for restoring missing punctuation in unsegmented text
http://bark.phon.ioc.ee/punctuator
MIT License

Multiple GPU Support #60

Open sauravjoshi opened 4 years ago

sauravjoshi commented 4 years ago

Hi there,

Nicely drafted and well-built project. I wanted to know whether training supports multiple GPUs.
I looked through models.py and found that no target is specified for theano.shared. I may be missing something, as I'm new to Theano. If multiple GPUs are supported, is there a specific config that needs to be provided?

Thanks, Saurav

bharat-patidar commented 4 years ago

Yes, I have the same question.

@ottokart Any help is appreciated!

ottokart commented 4 years ago

Hi!

punctuator currently does not support multiple GPUs. I've managed to train all my models in a few days, which has been acceptable for me.

To add it, these functions in models.py would have to be modified to accept a target parameter and pass it on to theano.shared:

```python
import numpy as np
import theano

# _get_shape(i, o, keepdims) is defined elsewhere in models.py.

def weights_const(i, o, name, const, keepdims=False):
    # Every entry initialized to the same constant.
    W_values = np.ones(_get_shape(i, o, keepdims)).astype(theano.config.floatX) * const
    return theano.shared(value=W_values, name=name, borrow=True)

def weights_identity(i, o, name, const, keepdims=False):
    # "A Simple Way to Initialize Recurrent Networks of Rectified Linear Units" (2015) (http://arxiv.org/abs/1504.00941)
    W_values = np.eye(*_get_shape(i, o, keepdims)).astype(theano.config.floatX) * const
    return theano.shared(value=W_values, name=name, borrow=True)

def weights_Glorot(i, o, name, rng, is_logistic_sigmoid=False, keepdims=False):
    # http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    d = np.sqrt(6. / (i + o))
    if is_logistic_sigmoid:
        d *= 4.
    W_values = rng.uniform(low=-d, high=d, size=_get_shape(i, o, keepdims)).astype(theano.config.floatX)
    return theano.shared(value=W_values, name=name, borrow=True)
```