dmlc / gluon-nlp

NLP made easy
https://nlp.gluon.ai/
Apache License 2.0
2.55k stars 538 forks source link

Making CandidateSamplers Blocks #385

Open leezu opened 5 years ago

leezu commented 5 years ago

Currently, the nlp.data.CandidateSampler classes are not Blocks but are based entirely on the imperative API, which adds some overhead. UnigramCandidateSampler (so far the only CandidateSampler implementation) can easily be changed into a HybridBlock.
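For context, with the current imperative implementation the sample shape is chosen at call time. A minimal sketch (the weights below are made up):

    import mxnet as mx
    import gluonnlp as nlp

    weights = mx.nd.array([1, 2, 3, 4])  # unnormalized unigram counts
    sampler = nlp.data.UnigramCandidateSampler(weights)
    candidates = sampler((2, 3))  # shape is passed per call; returns an NDArray of shape (2, 3)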

Because F.random.uniform does not support specifying the output shape via an NDArray (/Symbol) and there is no F.random.uniform_like, changing to a HybridBlock would require fixing the shape when constructing the sampler.
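To make the limitation concrete, a sketch (not part of the patch below):

    import mxnet as mx

    # Works: the output shape is a static Python tuple known when the graph is built.
    s = mx.sym.random.uniform(low=0, high=1, shape=(2, 3))

    # The shape cannot be taken from another Symbol/NDArray, and (at the time of
    # this issue) there is no random.uniform_like, which is why the HybridBlock
    # version in the patch takes `shape` in its constructor.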

Let's discuss whether that is reasonable, or if/when we can work around it.

modified   gluonnlp/data/candidate_sampler.py
@@ -20,10 +20,13 @@

 __all__ = ['CandidateSampler', 'UnigramCandidateSampler']

+import functools
+import operator
+
 import mxnet as mx

-class CandidateSampler(object):
+class CandidateSampler(mx.gluon.Block):
     """Abstract Candidate Sampler

     After initializing one of the concrete candidate sample implementations,
@@ -35,7 +38,7 @@ class CandidateSampler(object):
         raise NotImplementedError

-class UnigramCandidateSampler(CandidateSampler):
+class UnigramCandidateSampler(mx.gluon.HybridBlock):
     """Unigram Candidate Sampler

     Draw random samples from a unigram distribution with specified weights
@@ -46,22 +49,27 @@ class UnigramCandidateSampler(CandidateSampler):
     weights : mx.nd.NDArray
         Unnormalized class probabilities. Samples are drawn and returned on the
         same context as weights.context.
+    shape : int or tuple of int
+        Shape of data to be sampled.

     """

-    def __init__(self, weights):
+    def __init__(self, weights, shape):
+        super(UnigramCandidateSampler, self).__init__()
+        self._shape = shape
+
         self._context = weights.context
         self.N = weights.size
         total_weights = weights.sum()
-        self.prob = (weights * self.N / total_weights).asnumpy().tolist()
-        self.alias = [0] * self.N
+        prob = (weights * self.N / total_weights).asnumpy().tolist()
+        alias = [0] * self.N

         # sort the data into the outcomes with probabilities
         # that are higher and lower than 1/N.
         low = []
         high = []
         for i in range(self.N):
-            if self.prob[i] < 1.0:
+            if prob[i] < 1.0:
                 low.append(i)
             else:
                 high.append(i)
@@ -71,23 +79,25 @@ class UnigramCandidateSampler(CandidateSampler):
             l = low.pop()
             h = high.pop()

-            self.alias[l] = h
-            self.prob[h] = self.prob[h] - (1.0 - self.prob[l])
+            alias[l] = h
+            prob[h] = prob[h] - (1.0 - prob[l])

-            if self.prob[h] < 1.0:
+            if prob[h] < 1.0:
                 low.append(h)
             else:
                 high.append(h)

         for i in low + high:
-            self.prob[i] = 1
-            self.alias[i] = i
+            prob[i] = 1
+            alias[i] = i

-        # convert to ndarrays
-        self.prob = mx.nd.array(self.prob, ctx=self._context)
-        self.alias = mx.nd.array(self.alias, ctx=self._context)
+        # store
+        prob = mx.nd.array(prob)
+        alias = mx.nd.array(alias)
+        self.prob = self.params.get_constant('prob', prob)
+        self.alias = self.params.get_constant('alias', alias)

-    def __call__(self, shape):
+    def hybrid_forward(self, F, ctx_selector, prob, alias):
         """Draw samples from uniform distribution and return sampled candidates.

         Parameters
@@ -100,14 +110,18 @@ class UnigramCandidateSampler(CandidateSampler):
         samples: NDArray
             The sampled candidate classes.
         """
-        idx = mx.nd.random.uniform(low=0, high=self.N, shape=shape,
-                                   ctx=self._context,
-                                   dtype='float64').floor().astype('float32')
-        prob = self.prob[idx]
-        alias = self.alias[idx]
-        where = mx.nd.random.uniform(shape=shape, ctx=self._context) < prob
+        flat_shape = functools.reduce(operator.mul, self._shape)
+        idx = F.random.uniform(
+            low=0,
+            high=self.N,
+            shape=flat_shape,
+            ctx=self._context,
+            dtype='float64').floor().astype('float32')
+        prob = F.gather_nd(prob, idx.reshape((1, -1)))
+        alias = F.gather_nd(alias, idx.reshape((1, -1)))
+        where = F.random.uniform(shape=flat_shape, ctx=self._context) < prob
         hit = idx * where
         alt = alias * (1 - where)
-        candidates = hit + alt
+        candidates = (hit + alt).reshape(self._shape)
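
A rough usage sketch of the proposed HybridBlock version (names and shapes are illustrative only; initialize() is needed because prob/alias become registered constants):

    import mxnet as mx

    weights = mx.nd.array([1, 2, 3, 4])
    # the sample shape is now fixed at construction time
    sampler = UnigramCandidateSampler(weights, shape=(2, 3))
    sampler.initialize()  # sets up the 'prob' and 'alias' constants
    # the first argument only drives the forward call; prob and alias are
    # passed in automatically by the HybridBlock machinery
    candidates = sampler(mx.nd.zeros((2, 3)))  # NDArray of shape (2, 3)
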
szha commented 5 years ago

Due to the delayed release upstream, we will likely not get to this until after 0.5.0.