google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Issue using `L2Loss` for 2-D regression problem :bug: #980

Open mapadofu opened 4 years ago

mapadofu commented 4 years ago

Description

The ranks of the targets and loss weights in a batch differ between the training and evaluation tasks.

In the output below, note that during the training task the (model output, targets, weights) arrays given to the L2Loss layer have shapes:

Model (256, 2)
Targets (256, 2)
Weights (256, 2)

(Note: I'm trying to do a 2-D regression.)

However, during the evaluation task the (model output, targets, weights) arrays given to the L2Loss layer have shapes:

Model (4, 2)
Targets (4,)
Weights (4,)

This breaks the evaluator. I'd expect all of these to have shape (4, 2) (or, more generally, (N, 2)), i.e. the same behavior as in the training task.
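To make the expected behavior concrete, here is a minimal NumPy sketch of the same weighted L2 computation used in `SizeReportL2Loss` in the reproduction below. `weighted_l2` is a hypothetical name, and plain NumPy stands in for `trax.fastmath.numpy`. With the training-task shapes it returns a scalar; with the evaluation-task shapes it cannot, and even without the explicit shape check `(4, 2) - (4,)` would not broadcast in NumPy, because the trailing dimensions 2 and 4 differ.

```python
import numpy as np

def weighted_l2(model_output, targets, weights):
    """Weighted mean squared error, mirroring the body of SizeReportL2Loss.f."""
    if model_output.shape != targets.shape or targets.shape != weights.shape:
        raise ValueError(
            f'shape mismatch: {model_output.shape}, {targets.shape}, {weights.shape}')
    l2 = weights * (model_output - targets) ** 2
    return np.sum(l2) / np.sum(weights)

# Training-task shapes: output, targets and weights are all (N, 2).
out = np.zeros((4, 2), dtype=np.float32)
tgt = np.ones((4, 2), dtype=np.float32)
w = np.ones((4, 2), dtype=np.float32)
print(weighted_l2(out, tgt, w))  # every squared error is 1.0, so the mean is 1.0

# Evaluation-task shapes: targets and weights are (N,), so the loss is undefined.
try:
    weighted_l2(out, np.ones(4, dtype=np.int32), np.ones(4, dtype=np.float32))
except ValueError as err:
    print(err)
```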

Environment information


Linux version 4.15.0-112-generic (buildd@lcy01-amd64-027) 
(gcc version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)) 
#113-Ubuntu SMP Thu Jul 9 23:41:39 UTC 2020

$ pip freeze | grep trax
trax==1.3.4

$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-datasets==3.2.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.8.0
tensorflow-metadata==0.23.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0

$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

import os
import trax
from trax import layers as tl
from trax.supervised import training
import numpy
import random

#train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
#eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()

def generate_samples():
    # (text, lat/lon)
    data= [
        ("Aberdeen MS",  numpy.array((33.824742, -88.554591)) ),
        ("Aberdeen SD", numpy.array((45.463186, -98.471033))),
        ("Aberdeen WA", numpy.array((46.976432, -123.795781))),
        ("Amite City LA", numpy.array((30.733723, -90.5208))),
        ("Amory MS", numpy.array((33.984789, -88.48001))),
        ("Amouli AS", numpy.array((-14.26556, -170.589772))),
        ("Amsterdam NY", numpy.array((42.953149, -74.19505)))
    ]
    #data= [
    #    ("Aberdeen MS",  numpy.array([1.0,])),
    #    ("Aberdeen SD", numpy.array([0.0,])),
    #    ("Aberdeen WA", numpy.array([0.0,])),
    #    ("Amite City LA", numpy.array([0.0,])),
    #    ("Amory MS", numpy.array([1.0,])),
    #    ("Amouli AS", numpy.array([0.0,])),
    #    ("Amsterdam NY", numpy.array([0.0,]))
    #]
    for i in range(1024*8):
        yield random.choice(data)

train_stream = generate_samples()
eval_stream = generate_samples()

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Regress to lat/lon
#    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)

print(next(train_stream))  # See one example.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
#    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[   8, 128,],
                             batch_sizes=[256,   64, 4],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )

train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.

example_batch = next(eval_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.

from trax.fastmath import numpy as jnp
from trax.layers.base import Fn
def SizeReportL2Loss():
  """Returns a layer that computes total L2 loss for one batch."""
  def f(model_output, targets, weights):  # pylint: disable=invalid-name
    """Returns elementwise-weighted L2 norm of `model_output - targets`.

    Args:
      model_output: Output from one batch, treated as an unanalyzed tensor.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`.
    """
    print("Model", model_output.shape)
    print("Targets", targets.shape)
    print("Weights", weights.shape)
    trax.shapes.assert_same_shape(model_output, targets)
    trax.shapes.assert_same_shape(targets, weights)
    l2 = weights * (model_output - targets)**2
    return jnp.sum(l2) / jnp.sum(weights)
  return Fn('L2Loss', f)

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
#    loss_layer=tl.CrossEntropyLoss(),
    loss_layer=SizeReportL2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
#    optimizer=trax.optimizers.RMSProp(),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[SizeReportL2Loss(),],
#     metrics=[tl.L2Loss(), tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 8 steps (batches).
training_loop.run(8)
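The script above checks only the first batch from each stream. Since `BucketByLength` is configured with `batch_sizes=[256, 64, 4]`, batches of 4 come from the longest-length bucket and may not appear in that first batch. A minimal sketch of a broader check, assuming each batch is an `(inputs, targets, weights)` tuple as in the printed shapes (`check_batch_shapes` is a hypothetical helper, not a Trax API):

```python
import numpy as np

def check_batch_shapes(batches, n_batches=20, target_rank=2):
    """Pulls up to n_batches (inputs, targets, weights) tuples and checks their shapes.

    Raises AssertionError on the first batch whose targets are not of rank
    target_rank or whose weights differ in shape from the targets; otherwise
    returns the sorted distinct (targets.shape, weights.shape) pairs seen.
    """
    seen = set()
    # zip stops after n_batches without pulling an extra batch from the stream.
    for _, (inputs, targets, weights) in zip(range(n_batches), batches):
        assert np.ndim(targets) == target_rank, (
            f'targets {np.shape(targets)} (inputs {np.shape(inputs)})')
        assert np.shape(weights) == np.shape(targets), (
            f'weights {np.shape(weights)} vs targets {np.shape(targets)}')
        seen.add((np.shape(targets), np.shape(weights)))
    return sorted(seen)
```

For example, `print(check_batch_shapes(data_pipeline(generate_samples())))` inspects a separately built stream, so `eval_batches_stream` itself is not consumed before it reaches `EvalTask`.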

Error logs:

2020-08-27 21:21:08.059843: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
('Aberdeen WA', array([  46.976432, -123.795781]))
2020-08-27 21:21:10.585194: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Couldn't resolve host 'metadata'".
shapes = [(256, 8), (256, 2), (256, 2)]
shapes = [(256, 8), (256, 2), (256, 2)]
/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Model (256, 2)
Targets (256, 2)
Weights (256, 2)
Model (4, 2)
Targets (4,)
Weights (4,)
Traceback (most recent call last):
  File "trax04.py", line 119, in <module>
    output_dir=output_dir)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/supervised/training.py", line 187, in __init__
    self.load_checkpoint()
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/supervised/training.py", line 588, in load_checkpoint
    self._model_in_training.init_from_file(path)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 311, in init_from_file
    weights_and_state_sig = self.weights_and_state_signature(input_signature)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 443, in weights_and_state_signature
    return abstract_init(input_signature)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/fastmath/jax.py", line 310, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/api.py", line 1753, in eval_shape
    *map(abstractify, args_flat))
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 324, in abstract_eval_fun
    instantiate=True, stage_out=True)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 423, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/dmcnamara/tfhub-env/lib/python3.6/site-packages/trax/layers/base.py", line 288, in init
    input_signature, trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/supervised/training.py, line 146
  layer input shapes: (Traced<ShapedArray(int32[4,1024]):JaxprTrace(level=0/0)>, Traced<ShapedArray(int32[4]):JaxprTrace(level=0/0)>, Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>)

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer L2Loss (in _forward_abstract):
  layer created in file [...]/trax/layers/base.py, line 704
  layer input shapes: (ShapeDtype{shape:(4, 2), dtype:float32}, Traced<ShapedArray(int32[4]):JaxprTrace(level=0/0)>, Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>)

  File [...]/jax/interpreters/partial_eval.py, line 324, in abstract_eval_fun
    instantiate=True, stage_out=True)

  File [...]/jax/interpreters/partial_eval.py, line 423, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer L2Loss (in pure_fn):
  layer created in file [...]/trax/layers/base.py, line 704
  layer input shapes: (ShapeDtype{shape:(4, 2), dtype:float32}, ShapeDtype{shape:(4,), dtype:int32}, ShapeDtype{shape:(4,), dtype:float32})

  File [...]/trax/layers/base.py, line 658, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 700, in _forward
    return f(*xs)

  File [...]/trax04.py, line 88, in f
    trax.shapes.assert_same_shape(model_output, targets)

  File [...]/site-packages/trax/shapes.py, line 138, in assert_same_shape
    assert_shape_equals(array1, array2.shape)

  File [...]/site-packages/trax/shapes.py, line 132, in assert_shape_equals
    'Invalid shape {}; expected {}.'.format(array.shape, shape)

AssertionError: Invalid shape (4, 2); expected (4,).
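One detail in the traceback may be worth ruling out before treating this as an `EvalTask` problem: the exception is raised from `training.Loop.__init__`, through `load_checkpoint` and `init_from_file`, before `training_loop.run(8)` is reached. The input signature reported there is `(int32[4, 1024], int32[4], float32[4])`, whereas the example batches printed from both streams are `[(256, 8), (256, 2), (256, 2)]` and `generate_samples` yields float lat/lon targets rather than integer scalars. If `~/output_dir/` still holds a checkpoint from an earlier run (for example the commented-out `imdb_reviews` / `CrossEntropyLoss` setup), the signature stored with it could be what is being replayed; that is an assumption the log does not confirm. A minimal stdlib sketch for checking it (both helpers are hypothetical, not part of Trax):

```python
import os
import tempfile

def existing_files(output_dir):
    """Returns the sorted file names already in output_dir ([] if it does not exist)."""
    path = os.path.expanduser(output_dir)
    if not os.path.isdir(path):
        return []
    return sorted(os.listdir(path))

def fresh_output_dir(prefix='trax_l2loss_'):
    """Creates and returns a new, empty directory, so no earlier checkpoint can be loaded."""
    return tempfile.mkdtemp(prefix=prefix)
```

Usage: inspect `existing_files('~/output_dir/')` first, or pass `output_dir=fresh_output_dir()` to `training.Loop` instead of `os.path.expanduser('~/output_dir/')` and see whether the evaluation shapes change.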
mapadofu commented 4 years ago

I can't figure out how to tag this as a bug.