Oneflow-Inc / oneflow

OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.
http://www.oneflow.org
Apache License 2.0
5.95k stars 669 forks

@flow.global_function(type="train") cannot be called after any training [Lazy mode] #3436

Open xiejiachen opened 4 years ago

xiejiachen commented 4 years ago
# mlp_mnist.py
import oneflow as flow
import oneflow.typing as tp

BATCH_SIZE = 100

@flow.global_function(type="train")
def train_job(
    images: tp.Numpy.Placeholder((BATCH_SIZE, 1, 28, 28), dtype=flow.float),
    labels: tp.Numpy.Placeholder((BATCH_SIZE,), dtype=flow.int32),
) -> tp.Numpy:
    with flow.scope.placement("cpu", "0:0"):
        initializer = flow.truncated_normal(0.1)
        reshape = flow.reshape(images, [images.shape[0], -1])
        hidden = flow.layers.dense(
            reshape,
            512,
            activation=flow.nn.relu,
            kernel_initializer=initializer,
            name="dense1",
        )
        logits = flow.layers.dense(
            hidden, 10, kernel_initializer=initializer, name="dense2"
        )
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)

    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.1])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(loss)

    return loss

check_point = flow.train.CheckPoint()
check_point.init()

(train_images, train_labels), (test_images, test_labels) = flow.data.load_mnist(
    BATCH_SIZE, BATCH_SIZE
)
for i, (images, labels) in enumerate(zip(train_images, train_labels)):
    loss = train_job(images, labels)
    if i % 20 == 0:
        print(loss.mean())

# Above is the tutorial code from https://docs.oneflow.org/quick_start/quickstart_in_3_min.html
# Below, @flow.global_function is called again after training has already run:
@flow.global_function(type="train")
def train_job1(
    images: tp.Numpy.Placeholder((BATCH_SIZE, 1, 28, 28), dtype=flow.float),
    labels: tp.Numpy.Placeholder((BATCH_SIZE,), dtype=flow.int32),
) -> tp.Numpy:
    with flow.scope.placement("cpu", "0:0"):
        initializer = flow.truncated_normal(0.1)
        reshape = flow.reshape(images, [images.shape[0], -1])
        hidden = flow.layers.dense(
            reshape,
            512,
            activation=flow.nn.relu,
            kernel_initializer=initializer,
            name="dense1",
        )
        logits = flow.layers.dense(
            hidden, 10, kernel_initializer=initializer, name="dense2"
        )
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)

    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.1])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(loss)

    return loss

The error is reported as:

eager_oneflow_function: FAILED
    ("Current mode is NORMAL_MODE"[True] and "Eager execution enabled"[False])

lazy_oneflow_function: FAILED
    (("Current mode is NORMAL_MODE"[True] and (not "Eager execution enabled"[False])) and (not "Session initialized"[True]))
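
The failure condition can be read off the log directly: a lazy job function is accepted only while the session has not yet been initialized, and the first `train_job(...)` call initialized it. A quick evaluation of the logged predicate (variable names are mine, paraphrasing the log, not OneFlow internals):

```python
# Values taken from the bracketed [True]/[False] annotations in the log.
normal_mode = True          # "Current mode is NORMAL_MODE"
eager_enabled = False       # "Eager execution enabled"
session_initialized = True  # "Session initialized" (set by the first train_job call)

# lazy_oneflow_function requires: NORMAL_MODE and not eager and not initialized
lazy_ok = normal_mode and (not eager_enabled) and (not session_initialized)
print(lazy_ok)  # False, which is why registering train_job1 fails
```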
doombeaker commented 4 years ago

Do you run it in an interactive environment such as Jupyter Notebook or IPython? By default, OneFlow programs can only be started as scripts in "lazy mode", like this:

python mlp_mnist.py

If you insist on running OneFlow code in interactive mode, insert the following line at the beginning of your script,

flow.enable_eager_execution(True)

so that OneFlow programs run in "eager mode".

"Eager mode" is experimental and still being improved.
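
In other words, under the lazy-mode design all `@flow.global_function` definitions have to appear before the first job call, because that first call compiles the graph and initializes the session. A toy pure-Python model of that ordering rule (my own sketch, not OneFlow's actual session code):

```python
registered = []
session_initialized = False

def global_function(fn):
    # Registration is only allowed while the session is uninitialized.
    if session_initialized:
        raise RuntimeError("session already initialized")
    registered.append(fn.__name__)
    return fn

@global_function
def train_job(x):
    return x

@global_function  # fine: defined before any job has been called
def train_job1(x):
    return x

session_initialized = True  # the first train_job(...) call would flip this
print(registered)  # ['train_job', 'train_job1']
```

Defining both jobs up front, then running the training loop, avoids the error without touching eager mode.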

xiejiachen commented 4 years ago

The same code works in [eager mode].

> Do you run it in an interactive environment such as Jupyter Notebook or IPython? By default, OneFlow programs can only be started as scripts in "lazy mode", like this:
>
> python mlp_mnist.py
>
> If you insist on running OneFlow code in interactive mode, insert the following line at the beginning of your script,
>
> flow.enable_eager_execution(True)
>
> so that OneFlow programs run in "eager mode".
>
> "Eager mode" is experimental and still being improved.

This problem is not caused by the IPython core; you can also reproduce it by running the code I gave above in a plain terminal.

xiejiachen commented 4 years ago

@lixinqi Reported from Slack: this behavior is caused by the lazy-mode design. Would it be better to give a warning such as:

"If you want to call global_function again, please clear the session first or use eager mode"?
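
A guard along those lines could look like this toy sketch (the `LazySession` class and the message wording are hypothetical, not OneFlow's real implementation):

```python
class LazySession:
    """Toy model: job registration locks once the session is initialized."""

    def __init__(self):
        self.initialized = False
        self.jobs = {}

    def global_function(self, fn):
        if self.initialized:
            raise RuntimeError(
                "If you want to call global_function again, "
                "please clear the session first or use eager mode."
            )
        self.jobs[fn.__name__] = fn
        return fn

    def run(self, name, *args):
        self.initialized = True  # the first call freezes registration
        return self.jobs[name](*args)


sess = LazySession()

@sess.global_function
def train_job(x):
    return x * 2

sess.run("train_job", 3)  # first call initializes the session

try:
    @sess.global_function
    def train_job1(x):
        return x + 1
except RuntimeError as err:
    print(err)  # the user sees the suggested warning instead of a bare FAILED report
```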

lixinqi commented 4 years ago

> @lixinqi Reported from Slack: this behavior is caused by the lazy-mode design. Would it be better to give a warning such as:
>
> "If you want to call global_function again, please clear the session first or use eager mode"?

Agree