Capture_stdout causes crash, GPU, Lux.jl #2913

Dale-Black opened 1 month ago

Dale-Black commented 1 month ago


I am encountering GPU memory issues while training a Lux model (1.4 million parameters, input image size (128, 128, 96)) in a Pluto notebook on an A100 GPU. The issue occurs around the 50th epoch, resulting in Malt.TerminatedWorkerException() errors. However, the same code runs without issues when executed outside of Pluto.


Steps to Reproduce:

  1. Open the Pluto notebook:
  2. Run the notebook and observe the GPU memory usage and errors.

Expected Behavior:

The training should proceed without encountering GPU memory issues or Malt.TerminatedWorkerException() errors.

Actual Behavior:

The training encounters GPU memory issues and Malt.TerminatedWorkerException() errors around the 50th epoch when running in a Pluto notebook. The same code runs without issues when executed outside of Pluto.

Additional Context:

fonsp commented 1 month ago

Hey Dale!! Thanks for the clear bug report!

Do you see any errors in the terminal where you launched Pluto? Does fix it, or does this crash, but with additional error messages?

Dale-Black commented 1 month ago

Just tested with workspace_use_distributed_stdlib = true and got this error at around epoch 50:

Worker 2 terminated.
Unhandled Task ERROR: EOFError: read end of file
 [1] (::Base.var"#wait_locked#739")(s::Sockets.TCPSocket, buf::IOBuffer, nb::Int64)
   @ Base ./stream.jl:947
 [2] unsafe_read(s::Sockets.TCPSocket, p::Ptr{UInt8}, nb::UInt64)
   @ Base ./stream.jl:955
 [3] unsafe_read
   @ ./io.jl:774 [inlined]
 [4] unsafe_read(s::Sockets.TCPSocket, p::Base.RefValue{NTuple{4, Int64}}, n::Int64)
   @ Base ./io.jl:773
 [5] read!
   @ ./io.jl:775 [inlined]
 [6] deserialize_hdr_raw
   @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/messages.jl:167 [inlined]
 [7] message_handler_loop(r_stream::Sockets.TCPSocket, w_stream::Sockets.TCPSocket, incoming::Bool)
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:172
 [8] process_tcp_streams(r_stream::Sockets.TCPSocket, w_stream::Sockets.TCPSocket, incoming::Bool)
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:133
 [9] (::Distributed.var"#103#104"{Sockets.TCPSocket, Sockets.TCPSocket, Bool})()
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:121

When using Malt, there are no errors in the terminal.

fonsp commented 1 month ago

Hm! That doesn't help; it just complains that it couldn't read data from a worker that had already shut down.

I made a branch, disable-logger-and-stdout, where log and stdout capture are disabled. Can you try this branch?

(@v1.10) pkg> add Pluto#disable-logger-and-stdout

Dale-Black commented 1 month ago

Interesting! Of course, the logging now all goes to the terminal, but that branch works fine.

fonsp commented 1 month ago

How about With regular released Pluto

Dale-Black commented 1 month ago

Sorry for the late reply, I'm just getting around to this. It looks like that works. I wonder whether the fact that I am port-forwarding Pluto from a remote cluster has anything to do with this issue. The current Pluto release with this launch setup now works, although of course it logs output to the terminal: = false, host = "", capture_stdout = false)
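For reference, here is a minimal sketch of how the two launch options discussed in this thread can be passed to `Pluto.run`. It assumes a Pluto release that supports both keyword arguments; the option names come from this thread, and the comments only summarize what was reported here. Run one call at a time, since `Pluto.run` blocks while the server is running.

```julia
import Pluto

# Tried first: run notebook workers with the Distributed stdlib instead of Malt.
# In this thread it did not prevent the crash (the worker still terminated).
# Pluto.run(workspace_use_distributed_stdlib = true)

# Workaround that did avoid the crash: disable stdout capture, so anything the
# notebook prints goes to the terminal that launched Pluto instead of the notebook.
Pluto.run(capture_stdout = false)
```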

fonsp commented 1 month ago

Strange! That means that capture_stdout is the problem. When running Pluto with capture_stdout=false, is the stdout printed to your terminal different from what you get when running the notebook without Pluto?

I wonder whether Lux has a problem with stdout being captured into a buffer (which is what Pluto does) instead of being printed directly.
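To make "captured into a buffer" concrete, below is a minimal sketch of redirecting stdout into an in-process pipe and collecting it as a `String`. It is not Pluto's implementation, only the same basic idea; it assumes Julia 1.7 or later (for the `redirect_stdout(f, stream)` method), and `capture_stdout_to_string` is a hypothetical helper name.

```julia
# Minimal sketch: run `f` with stdout redirected into an in-process Pipe and
# return everything `f` printed as a String. The original stdout is restored
# afterwards, even if `f` throws.
function capture_stdout_to_string(f::Function)
    pipe = Pipe()
    Base.link_pipe!(pipe; reader_supports_async = true, writer_supports_async = true)

    # Drain the read end (`pipe.out`) in a background task, so a large amount of
    # output cannot fill the OS pipe buffer and block the writer.
    reader = @async read(pipe.out, String)

    try
        # `redirect_stdout(f, stream)` points `stdout` at the pipe's write end
        # (`pipe.in`) only while `f()` runs, then restores the previous stdout.
        redirect_stdout(f, pipe.in)
    finally
        close(pipe.in)  # closing the write end lets the reader task see EOF
    end
    return fetch(reader)
end

captured = capture_stdout_to_string() do
    println("hello")
    print("world")
end

println(repr(captured))
```

Pluto's version additionally redirects stderr, installs a `TextDisplay`, and turns the collected text into log messages instead of returning it.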

Can you try to run your code without Pluto, but using Pluto's stdout capturing system?

Step 1: open a REPL.

Step 2: run this code (these are some snippets from Pluto's source code). You could also do import Pluto.PlutoRunner.with_io_to_logs.

import Logging

const default_stdout_iocontext = IOContext(devnull,
    :color => true,
    :limit => true,
    :displaysize => (18, 75),
    :is_pluto => false,
)

const stdout_log_level = Logging.LogLevel(-555)

# Turn whatever has accumulated in `output` into a single log message.
function _send_stdio_output!(output, loglevel)
    output_str = String(take!(output))
    if !isempty(output_str)
        Logging.@logmsg loglevel output_str
    end
end

function with_io_to_logs(f::Function; enabled::Bool=true, loglevel::Logging.LogLevel=Logging.LogLevel(1))
    if !enabled
        return f()
    end
    # Taken from with some modifications to make it log.

    # Original implementation from Documenter.jl (MIT license)
    # Save the default output streams.
    default_stdout = stdout
    default_stderr = stderr
    # Redirect both the `stdout` and `stderr` streams to a single `Pipe` object.
    pipe = Pipe()
    Base.link_pipe!(pipe; reader_supports_async = true, writer_supports_async = true)
    pe_stdout = IOContext(pipe.in, default_stdout_iocontext)
    pe_stderr = IOContext(pipe.in, default_stdout_iocontext)
    redirect_stdout(pe_stdout)
    redirect_stderr(pe_stderr)

    # Bytes written to the `pipe` are captured in `output` and eventually converted to a
    # `String`. We need to use an asynchronous task to continuously transfer bytes from the
    # pipe to `output` in order to avoid the buffer filling up and stalling write() calls in
    # user code.
    execution_done = Ref(false)
    output = IOBuffer()

    @async begin
        pipe_reader = Base.pipe_reader(pipe)
        try
            while !eof(pipe_reader)
                write(output, readavailable(pipe_reader))

                # NOTE: we don't really have to wait for the end of execution to stream output logs
                #       so maybe we should just enable it?
                if execution_done[]
                    _send_stdio_output!(output, loglevel)
                end
            end
            _send_stdio_output!(output, loglevel)
        catch err
            @error "Failed to redirect stdout/stderr to logs"  exception=(err,catch_backtrace())
            if err isa InterruptException
                rethrow(err)
            end
        end
    end

    # To make the `display` function work.
    redirect_display = TextDisplay(IOContext(pe_stdout, default_stdout_iocontext))
    pushdisplay(redirect_display)

    # Run the function `f`, capturing all output that it might have generated.
    # `finally` restores the display and output streams even if `f` throws.
    result = try
        f()
    finally
        # Restore display
        try
            popdisplay(redirect_display)
        catch e
            # This happens when the user calls `popdisplay()`, fine.
            # @warn "Pluto's display was already removed?" e
        end

        execution_done[] = true

        # Restore the original output streams.
        redirect_stdout(default_stdout)
        redirect_stderr(default_stderr)
        # Closing the write end lets the reader task hit EOF and log the remainder.
        close(pipe.in)
    end

    result
end

Step 3: save your notebook as script.jl

Step 4:

with_io_to_logs() do
    include("script.jl")
end

Dale-Black commented 4 weeks ago

That also seems to work fine; no memory issues.

fonsp commented 4 weeks ago

Can you try this:

julia> import Pluto

julia> using UUIDs

julia> nbid = cellid = uuid1()

julia> c = Channel(10_000);

julia> Pluto.PlutoRunner.setup_plutologger(nbid, c);

# after some setup, we now run the expression just like Pluto would, using the full log capture, timing, try catch, etc

julia> Pluto.PlutoRunner.run_expression(Main, :(include("script.jl")), nbid, cellid);
┌ Warning: @Dale you can ignore whatever it says here
└ @ Pluto.PlutoRunner ~/Documents/Pluto.jl/src/runner/PlutoRunner/src/PlutoRunner.jl:282

# see the logs that were captured:

julia> while isready(c)
           println(take!(c))
       end

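The REPL session above hands the channel `c` to `Pluto.PlutoRunner.setup_plutologger`, which then collects log records there. As a self-contained illustration of that general mechanism, here is a hypothetical logger built only on the `Logging` stdlib. `ChannelLogger` is an invented name, not Pluto's logger, and the `-555` level merely mirrors the `stdout_log_level` constant from the earlier snippet.

```julia
import Logging

# Hypothetical logger (not Pluto's) that pushes every accepted log record into a Channel.
struct ChannelLogger <: Logging.AbstractLogger
    channel::Channel{Any}
    min_level::Logging.LogLevel
end

# Records below `min_level` are filtered out before `handle_message` is called.
Logging.min_enabled_level(l::ChannelLogger) = l.min_level
Logging.shouldlog(::ChannelLogger, level, _module, group, id) = true
Logging.catch_exceptions(::ChannelLogger) = true

function Logging.handle_message(l::ChannelLogger, level, message, _module, group, id, file, line; kwargs...)
    put!(l.channel, (level = level, message = string(message)))
    return nothing
end

c = Channel{Any}(10_000)
logger = ChannelLogger(c, Logging.LogLevel(-555))

Logging.with_logger(logger) do
    @info "hello"                                              # level 0, kept
    Logging.@logmsg Logging.LogLevel(-555) "captured stdout"   # exactly at min_level, kept
    @debug "dropped"                                           # level -1000, filtered out
end

# Drain the channel, like the `while isready(c)` loop above.
records = Any[]
while isready(c)
    push!(records, take!(c))
end
foreach(println, records)
```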
fonsp commented 4 weeks ago

It's really hard for me to debug this because I don't have a GPU. Hopefully my last comment will reproduce the error "outside of Pluto", in which case you could report it to the package developers. You could then narrow down the cause by commenting out more and more lines in PlutoRunner.run_expression.

If this is not it, I will mark this as "unsolvable by me" because of limited time and resources, but maybe someone else can debug it. You just need to modify PlutoRunner.run_expression, restart the notebook file (you don't need to fully restart Pluto), see what happens, and repeat. The goal is to narrow down which line causes the crash and figure out why. It seems to be related to stdout/logging.

Dale-Black commented 4 weeks ago

Okay, thank you! I will keep chipping away at this and keep you updated.