google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0
8.1k stars, 816 forks

How to perform linear regression using 1 feature #1329

Closed (aycandv closed this issue 3 years ago)

aycandv commented 3 years ago

Description

I want to use Trax to solve a linear regression problem (e.g., converting Celsius to Fahrenheit).

Since the training and evaluation streams in the documentation are prepared by processing words, I am confused about how to use only numbers instead of words. ...
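For reference, here is a minimal NumPy sketch of purely numeric data for this problem. It assumes the exact conversion F = 1.8 · C + 32, which is a linear model with one feature (weight 1.8, bias 32). The value range and sample count are arbitrary. The names `celcius` and `fahrenheit` are chosen to match the variables the reproduction below uses but does not define.

```python
import numpy as np

# Hypothetical data: 100 evenly spaced Celsius temperatures (range chosen arbitrarily).
celcius = np.linspace(-40.0, 100.0, num=100, dtype=np.float32)

# Exact conversion F = 1.8 * C + 32, i.e. a one-feature linear model
# y = w * x + b with w = 1.8 and b = 32.
fahrenheit = (1.8 * celcius + 32.0).astype(np.float32)
```

Each Celsius value is a single continuous feature, and each Fahrenheit value is its regression target. No tokenization or vocabulary is involved.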

Environment information

OS: Linux

$ pip freeze | grep trax
trax==1.3.7

$ pip freeze | grep tensor
mesh-tensorflow==0.1.18
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorboardcolab==0.0.22
tensorflow==2.4.0
tensorflow-addons==0.8.3
tensorflow-datasets==4.0.1
tensorflow-estimator==2.4.0
tensorflow-gcs-config==2.4.0
tensorflow-hub==0.10.0
tensorflow-metadata==0.26.0
tensorflow-privacy==0.2.2
tensorflow-probability==0.11.0
tensorflow-text==2.4.1

$ pip freeze | grep jax
jax==0.2.7
jaxlib==0.1.57+cuda101

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

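# Steps to reproduce:
import random

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.layers.base import Fn
from trax.supervised import training

# `celcius` and `fahrenheit` are sequences of paired temperatures defined in
# an earlier cell (not shown).
def generate_samples():
    """Yields 100 randomly chosen (celsius, fahrenheit) pairs, one pair at a time.

    Note: each sample is an unbatched pair of scalars.
    """
    data = [(celcius[i], fahrenheit[i]) for i in range(len(celcius))]
    for _ in range(100):
        yield random.choice(data)

train_stream = generate_samples()
eval_stream = generate_samples()
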
def SizeReportL2Loss():
  """Returns a layer that computes total L2 loss for one batch."""
  def f(model_output, targets, weights): 
    """Returns elementwise-weighted L2 norm of `model_output - targets`.

    Args:
      model_output: Output from one batch, treated as an unanalyzed tensor.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`.
    """
    print("Model", model_output.shape)
    print("Targets", targets.shape)
    print("Weights", weights.shape)
    trax.shapes.assert_same_shape(model_output, targets)
    trax.shapes.assert_same_shape(targets, weights)
    l2 = weights * (model_output - targets)**2
    return jnp.sum(l2) / jnp.sum(weights)
  return Fn('L2Loss', f)

# Training task.
train_task = training.TrainTask(
    labeled_data=train_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
#    optimizer=trax.optimizers.RMSProp(),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_stream,
    metrics=[SizeReportL2Loss(),],
#     metrics=[tl.L2Loss(), tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
import os
output_dir = os.path.expanduser('~/output_dir/')
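# `model` is defined in an earlier cell (not shown); the log below suggests
# it is a Serial whose first sublayer is an Embedding layer.
training_loop = training.Loop(model,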
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 8 steps (batches).
training_loop.run(8)
# Error logs:
---------------------------------------------------------------------------
LayerError                                Traceback (most recent call last)
<ipython-input-70-f50ba089649f> in <module>()
     48                               train_task,
     49                               eval_tasks=[eval_task],
---> 50                               output_dir=output_dir)
     51 # Run 2000 steps (batches).
     52 training_loop.run(8)

/usr/local/lib/python3.6/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, callbacks)
    212     if not use_memory_efficient_trainer:
    213       if _is_uninitialized(self._model):
--> 214         self._model.init(self._batch_signature)
    215       self._eval_model.rng = self.new_rng()
    216       if _is_uninitialized(self._eval_model):

/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in init(self, input_signature, rng, use_cache)
    305       name, trace = self._name, _short_traceback(skip=3)
    306       raise LayerError(name, 'init', self._caller,
--> 307                        input_signature, trace) from None
    308 
    309   def init_from_file(self, file_name, weights_only=False, input_signature=None):

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-61-73a191646e65>, line 4
  layer input shapes: (ShapeDtype{shape:(), dtype:float32}, ShapeDtype{shape:(), dtype:float32})

  File [...]/trax/layers/combinators.py, line 106, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Embedding_100_256 (in _forward_abstract):
  layer created in file [...]/<ipython-input-61-73a191646e65>, line 2
  layer input shapes: ShapeDtype{shape:(), dtype:float32}

  File [...]/jax/interpreters/partial_eval.py, line 416, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)

  File [...]/jax/interpreters/partial_eval.py, line 1201, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/dist-packages/jax/linear_util.py, line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/dist-packages/jax/linear_util.py, line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer Embedding_100_256 (in pure_fn):
  layer created in file [...]/<ipython-input-61-73a191646e65>, line 2
  layer input shapes: ShapeDtype{shape:(), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 181, in forward
    embedded = jnp.take(self.weights, x, axis=0)

  File [...]/_src/numpy/lax_numpy.py, line 3936, in take
    slice_sizes=tuple(slice_sizes))

  File [...]/_src/lax/lax.py, line 867, in gather
    slice_sizes=canonicalize_shape(slice_sizes))

  File [...]/dist-packages/jax/core.py, line 271, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1073, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 2011, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)

  File [...]/_src/lax/lax.py, line 4055, in _gather_dtype_rule
    raise ValueError("start_indices must have an integer type")

ValueError: start_indices must have an integer type
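The last frames point to the cause. The model's first sublayer, `Embedding_100_256`, looks up rows with `jnp.take(self.weights, x, axis=0)`, so its input must be integer token IDs. Each input in this stream is a float32 scalar with shape `()`. The NumPy sketch below is not Trax code, and all names and values in it are illustrative. It contrasts an embedding-style lookup, which indexes by integers, with a dense-style affine layer, which accepts continuous floats and is what a one-feature regression needs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Embedding-style lookup: a (vocab_size, d_feature) table indexed by integer IDs,
# analogous to jnp.take(weights, ids, axis=0). Sizes mirror Embedding_100_256.
embedding_table = rng.normal(size=(100, 256)).astype(np.float32)
token_ids = np.array([3, 7, 42])                          # integer indices
embedded = np.take(embedding_table, token_ids, axis=0)    # shape (3, 256)

# Dense-style layer: an affine map that accepts continuous float inputs.
# Inputs have shape (batch_size, n_features) = (4, 1) for one feature.
celsius_batch = np.array([[0.0], [37.0], [100.0], [-40.0]], dtype=np.float32)
w = rng.normal(size=(1, 1)).astype(np.float32)   # (n_features, n_outputs)
b = np.zeros((1,), dtype=np.float32)             # bias
predictions = celsius_batch @ w + b              # shape (4, 1)
```

In Trax terms, this roughly suggests starting the model with a dense layer (for example `tl.Dense(1)`) instead of an embedding. Treat that as a suggestion; the thread does not confirm it.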

thoo commented 3 years ago

I am planning to write some tutorials but haven't had much time. Here is a notebook on simple linear and logistic regression that I wrote: https://github.com/thoo/trax-tutorial/blob/master/basic_regression_tensorboard.ipynb
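The notebook is not reproduced here. As an independent illustration of the same idea, here is a minimal NumPy sketch of one-feature linear regression. It minimizes the weighted L2 loss used in `SizeReportL2Loss` above with plain full-batch gradient descent. Standardizing the feature, the learning rate, the step count, and the helper names `weighted_l2_loss` and `fit_linear_1d` are assumptions of this sketch, not taken from the notebook.

```python
import numpy as np

def weighted_l2_loss(pred, targets, weights):
    """Weighted mean squared error: sum(w * (pred - t)**2) / sum(w)."""
    return np.sum(weights * (pred - targets) ** 2) / np.sum(weights)

def fit_linear_1d(x, y, lr=0.1, n_steps=500):
    """Fits y ~ w * x + b by full-batch gradient descent on the weighted L2 loss.

    x is standardized first so one learning rate suits both parameters;
    the fitted parameters are mapped back to the original units at the end.
    """
    weights = np.ones_like(y)                 # uniform per-example weights
    mean, std = x.mean(), x.std()
    xs = (x - mean) / std                     # standardized feature
    w_s, b_s = 0.0, 0.0
    scale = 2.0 / np.sum(weights)
    for _ in range(n_steps):
        residual = (w_s * xs + b_s) - y
        # Gradients of the weighted L2 loss with respect to w_s and b_s.
        grad_w = scale * np.sum(weights * residual * xs)
        grad_b = scale * np.sum(weights * residual)
        w_s -= lr * grad_w
        b_s -= lr * grad_b
    # Undo the standardization: w_s * (x - mean) / std + b_s == w * x + b.
    w = w_s / std
    b = b_s - w_s * mean / std
    return w, b, weighted_l2_loss(w * x + b, y, weights)

celsius = np.linspace(-40.0, 100.0, 50)
fahrenheit = 1.8 * celsius + 32.0
w, b, loss = fit_linear_1d(celsius, fahrenheit)
print(f"w={w:.4f}, b={b:.4f}, loss={loss:.2e}")
```

On exactly linear data like this, the fit should recover a weight close to 1.8 and a bias close to 32.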

aycandv commented 3 years ago

Thanks for your help.

mkaramib commented 3 years ago

The data generator should output JAX or NumPy arrays.
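Here is a minimal sketch of that suggestion. It assumes Trax's `TrainTask` consumes an iterator of `(inputs, targets, weights)` tuples whose arrays share a leading batch dimension. The generator name `batched_samples` and the batch size are hypothetical. The generator yields float32 NumPy arrays of shape `(batch_size, 1)` and loops forever. By contrast, the original `generate_samples` yields unbatched scalar pairs and stops after 100 samples.

```python
import numpy as np

def batched_samples(x, y, batch_size=16, seed=0):
    """Yields (inputs, targets, weights) batches forever (hypothetical helper).

    Each array is float32 with shape (batch_size, 1): a leading batch
    dimension and one feature, instead of a tuple of Python scalars.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=np.float32).reshape(-1, 1)
    y = np.asarray(y, dtype=np.float32).reshape(-1, 1)
    while True:                                        # never exhausted during training
        idx = rng.integers(0, len(x), size=batch_size)   # sample with replacement
        inputs, targets = x[idx], y[idx]
        weights = np.ones_like(targets)                # same shape as targets
        yield inputs, targets, weights

celcius = np.linspace(-40.0, 100.0, 100)
fahrenheit = 1.8 * celcius + 32.0
train_stream = batched_samples(celcius, fahrenheit)
inputs, targets, weights = next(train_stream)
```

Float inputs like these still need a model that starts with a dense layer rather than an embedding. They also need a regression loss, such as an L2 loss, in place of `tl.CrossEntropyLoss()`.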