dotnet / roslyn

The Roslyn .NET compiler provides C# and Visual Basic languages with rich code analysis APIs.
https://docs.microsoft.com/dotnet/csharp/roslyn-sdk/
MIT License
19.11k stars 4.04k forks source link

Attribute async state machines for better introspection #38994

Open bencyoung opened 5 years ago

bencyoung commented 5 years ago

Hi,

I've recently written a PoC of an async method builder that allows you to suspend the thread of execution to disk and rehydrate it at a later point, or even on a different machine. It appears to work OK (as long as the captures are serializable, obviously), but writing the code to demangle the state machine fields was tricky and I'm pretty sure it isn't quite right as it stands. I've attached the code, but it would be really nice if the state machine builder put some attributes on the fields so we could inspect them. E.g.:

- `[StateMachineState]` on the integer state field
- `[StateMachineInput("i")]` for the method arguments
- `[StateMachineLocal("j")]` for captured locals
- `[StateMachineAwaiter(index)]` or similar for the awaiters

I don't think it needs a lot more. Having exposed the state machine builders I'm not sure the internals of state machines could be changed easily so I don't think this is adding much of a backwards compatibility issue.

On a side note, is it possible to remove the debug bypass on the state machines, to make it easier to see what's happening under the covers?

Thanks, Ben Young

using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace MethodBuilderTest
{
    /// <summary>
    /// Demo entry point: starts an endless flow that checkpoints between steps.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            Console.WriteLine("Started");
            Flow.Run(MyFlowData);
        }

        // The locals below keep their names and declaration order on purpose:
        // FlowState derives the checkpoint keys from the compiler-generated
        // state machine field names, which are based on these identifiers.

        /// <summary>
        /// Outer flow: loops forever, running <see cref="Letters"/> for an
        /// increasing counter, checkpointing, pausing one second and printing
        /// how many iterations have been recorded.
        /// </summary>
        static async FlowData MyFlowData()
        {
            int i = 0;
            var numbers = new List<int>();
            for (; ; )
            {
                await Letters(i++);
                numbers.Add(i);
                await Flow.Checkpoint();
                await Task.Delay(TimeSpan.FromSeconds(1));
                Console.WriteLine($"  {numbers.Count}");
            }
        }

        /// <summary>
        /// Inner flow: prints <paramref name="i"/> next to the letters A..E,
        /// checkpointing and pausing half a second after each letter.
        /// </summary>
        static async FlowData Letters(int i)
        {
            for (int j = 0; j < 5; j++)
            {
                Console.WriteLine($"{i} {(char)('A' + j)}");
                await Flow.Checkpoint();
                await Task.Delay(TimeSpan.FromSeconds(0.5));
            }
        }
    }

    /// <summary>
    /// Static entry points for running a flow and for marking checkpoints in it.
    /// </summary>
    static class Flow
    {
        /// <summary>
        /// The context of the flow currently being run; null when no flow is running.
        /// </summary>
        internal static FlowContext FlowContext { get; set; }

        /// <summary>
        /// Installs a fresh <see cref="FlowContext"/>, invokes <paramref name="data"/>
        /// and blocks on the returned flow's result. The context is cleared again
        /// whether the flow finishes or throws.
        /// </summary>
        public static void Run(Func<FlowData> data)
        {
            FlowContext = new FlowContext();
            try
            {
                var awaiter = data().GetAwaiter();
                awaiter.GetResult();
            }
            finally
            {
                FlowContext = null;
            }
        }

        /// <summary>
        /// Returns a <see cref="FlowData"/> whose result has already been set.
        /// </summary>
        public static FlowData Checkpoint()
        {
            var completed = new FlowData();
            completed.SetResult();
            return completed;
        }
    }

    /// <summary>
    /// Marker object for a running flow. Flow.Run stores an instance in
    /// Flow.FlowContext for the duration of a run, and the method builder only
    /// checks that property for null. The type carries no data of its own.
    /// </summary>
    class FlowContext
    {
    }

    /// <summary>
    /// Empty placeholder struct. Nothing else in the visible code references it.
    /// </summary>
    struct Void
    {
    }

    /// <summary>
    /// Task-like return type for flow methods; it is its own awaiter.
    /// Completion is signalled by <see cref="SetResult"/> or
    /// <see cref="SetException"/>, and <see cref="GetResult"/> blocks until one
    /// of them has happened.
    /// </summary>
    [AsyncMethodBuilder(typeof(FlowDataMethodBuilder))]
    class FlowData : INotifyCompletion
    {
        // Both fields are written by whichever thread completes the flow and
        // polled by GetResult on another thread, hence volatile.
        private volatile ExceptionDispatchInfo exception;

        [JsonProperty]
        private volatile bool hasResult;

        /// <summary>Returns this instance; FlowData acts as its own awaiter.</summary>
        public FlowData GetAwaiter() => this;

        /// <summary>
        /// Blocks the calling thread (yielding with Thread.Sleep(0)) until the
        /// flow has a result, or rethrows the captured exception if it faulted.
        /// </summary>
        public void GetResult()
        {
            while (true)
            {
                // Re-check the exception on every iteration: a flow that faults
                // after the wait has started never sets hasResult, so checking
                // only once up front would spin forever.
                this.exception?.Throw();

                if (this.hasResult)
                {
                    return;
                }

                Thread.Sleep(0);
            }
        }

        /// <summary>
        /// True only once an exception has been recorded. A FlowData that merely
        /// has a result still reports false.
        /// NOTE(review): presumably deliberate, so that awaiting a FlowData always
        /// goes through the builder's AwaitOnCompleted path where the checkpoint
        /// is written - confirm before changing.
        /// </summary>
        public bool IsCompleted => this.exception != null;

        /// <summary>
        /// Schedules <paramref name="completion"/> on the thread pool straight
        /// away; it does not wait for the flow to complete (GetResult does that).
        /// </summary>
        public void OnCompleted(Action completion)
        {
            Task.Run(completion);
        }

        /// <summary>Records <paramref name="e"/> so GetResult rethrows it.</summary>
        internal void SetException(Exception e)
        {
            this.exception = ExceptionDispatchInfo.Capture(e);
        }

        /// <summary>Marks the flow as having a result, releasing GetResult.</summary>
        internal void SetResult()
        {
            this.hasResult = true;
        }
    }

    /// <summary>
    /// A <see cref="FlowData"/> that also remembers the name of the state machine
    /// type behind it. The name is serialized with the checkpoint, and the method
    /// builder uses it to re-create the child state machine on restore.
    /// </summary>
    [AsyncMethodBuilder(typeof(FlowDataMethodBuilder))]
    class StateMachineFlowData : FlowData
    {
        /// <summary>Name of the state machine type, as passed to the constructor.</summary>
        [JsonProperty]
        internal string stateMachineName;

        public StateMachineFlowData(string name) => this.stateMachineName = name;
    }

    /// <summary>
    /// Holds an async state machine and reads/writes its instance fields through
    /// reflection, so its state can be captured into, and restored from, a
    /// name/value dictionary.
    /// </summary>
    /// <remarks>
    /// Field names are mapped by <see cref="DemangleName"/>. Two fields that map
    /// to the same key make <see cref="Capture"/> throw, because ToDictionary
    /// rejects duplicate keys.
    /// NOTE(review): this would presumably happen with two hoisted locals that
    /// share a source name - confirm against the compiler's naming scheme.
    /// </remarks>
    class FlowState<TStateMachine>
    {
        private const BindingFlags InstanceFields =
            BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        internal TStateMachine stateMachine;

        /// <summary>
        /// Writes whether the state machine type is a class, followed by one
        /// "key value" line per captured field, to the console.
        /// </summary>
        public void Debug()
        {
            Console.WriteLine($"{this.stateMachine.GetType().IsClass}");
            Console.WriteLine(string.Join(
                Environment.NewLine,
                this.Capture()
                    .Select(f => $"{f.Key} {f.Value}")));
        }

        /// <summary>
        /// Snapshots the state machine's fields, keyed by demangled name.
        /// Fields whose name demangles to null/empty are skipped, as are fields
        /// whose type implements ICriticalNotifyCompletion.
        /// </summary>
        public IReadOnlyDictionary<string, object> Capture()
        {
            // Filter first, read values last: skipped fields are never read.
            return this.stateMachine
                    .GetType()
                    .GetFields(InstanceFields)
                    .Where(f => !typeof(ICriticalNotifyCompletion).IsAssignableFrom(f.FieldType))
                    .Select(f => new { Name = DemangleName(f.Name), Field = f })
                    .Where(f => !string.IsNullOrEmpty(f.Name))
                    .ToDictionary(f => f.Name, f => f.Field.GetValue(this.stateMachine));
        }

        /// <summary>
        /// Writes the values in <paramref name="state"/> back into the matching
        /// state machine fields, then stores <paramref name="builder"/> in the
        /// "&lt;&gt;t__builder" field.
        /// </summary>
        /// <param name="builder">Builder to attach to the state machine.</param>
        /// <param name="state">Values keyed by demangled field name; missing keys and null values are left untouched.</param>
        /// <exception cref="InvalidOperationException">The state machine type has no "&lt;&gt;t__builder" field.</exception>
        public void RestoreState<TBuilder>(TBuilder builder, IReadOnlyDictionary<string, object> state)
        {
            var machineType = this.stateMachine.GetType();

            // A typed reference lets SetValueDirect write into the field in
            // place, without going through a boxed copy of the state machine.
            var tr = __makeref(this.stateMachine);

            foreach (var field in machineType.GetFields(InstanceFields))
            {
                var key = DemangleName(field.Name);
                if (string.IsNullOrEmpty(key))
                {
                    continue;
                }

                if (state.TryGetValue(key, out var value) && value != null)
                {
                    // Values that are not directly assignable to the field type
                    // are converted with Convert.ChangeType.
                    var converted = field.FieldType.IsAssignableFrom(value.GetType())
                        ? value
                        : Convert.ChangeType(value, field.FieldType);
                    field.SetValueDirect(tr, converted);
                }
            }

            var builderField = machineType.GetField("<>t__builder", InstanceFields);
            if (builderField == null)
            {
                throw new InvalidOperationException(
                    $"Type '{machineType}' has no '<>t__builder' field; is it a compiler-generated async state machine?");
            }

            builderField.SetValueDirect(tr, builder);
        }

        /// <summary>
        /// Maps a state machine field name to a dictionary key:
        /// names starting with "&lt;&gt;" and containing "__state" become "state!";
        /// names starting with "&lt;&gt;u_" become "awaiter" + remainder + "!";
        /// any other name starting with "&lt;&gt;" yields null (skipped);
        /// "&lt;name&gt;..." becomes "name_local!";
        /// anything else is returned unchanged.
        /// A name that starts with "&lt;" but has no closing "&gt;" yields null.
        /// </summary>
        private static string DemangleName(string name)
        {
            if (name.StartsWith("<>", StringComparison.Ordinal))
            {
                if (name.Contains("__state"))
                {
                    return "state!";
                }

                if (name.StartsWith("<>u_", StringComparison.Ordinal))
                {
                    return "awaiter" + name.Substring("<>u_".Length) + "!";
                }

                return null;
            }

            if (name.StartsWith("<", StringComparison.Ordinal))
            {
                int close = name.IndexOf('>', 1);

                // No closing bracket: not a recognised pattern, so skip the
                // field rather than pass a negative length to Substring.
                return close < 0 ? null : name.Substring(1, close - 1) + "_local!";
            }

            return name;
        }
    }

    /// <summary>
    /// Async method builder for FlowData-returning methods. Each time the method
    /// awaits a FlowData, the state machine's captured fields are written to a
    /// JSON file. When the method starts and such a file exists, the fields are
    /// restored from it before the first MoveNext.
    /// </summary>
    /// <remarks>
    /// The checkpoint file name is derived only from the state machine type's
    /// Name, so every invocation of the same async method uses the same file.
    /// </remarks>
    class FlowDataMethodBuilder
    {
        // SECURITY NOTE(review): TypeNameHandling.All makes the deserializer
        // instantiate whatever types the checkpoint file names. That is only
        // safe while checkpoint files come from a trusted location; consider
        // restricting types with a serialization binder before reading files
        // from anywhere else.
        private static readonly JsonSerializer Serializer = JsonSerializer.Create(new JsonSerializerSettings
        {
            PreserveReferencesHandling = PreserveReferencesHandling.All,
            TypeNameHandling = TypeNameHandling.All,
            Formatting = Formatting.Indented
        });

        private string fileName;

        private Type stateMachineType;

        // Holds a FlowState<TStateMachine>; typed as object because the builder
        // itself is not generic.
        private object state;

        private StateMachineFlowData result;

        public static FlowDataMethodBuilder Create() => new FlowDataMethodBuilder();

        /// <summary>The FlowData handed back to the caller of the async method.</summary>
        public FlowData Task => this.result;

        /// <summary>
        /// Records the state machine, restores any checkpoint found on disk and
        /// runs the first MoveNext.
        /// </summary>
        /// <exception cref="InvalidOperationException">No flow context is set.</exception>
        public void Start<TStateMachine>(ref TStateMachine stateMachine)
            where TStateMachine : IAsyncStateMachine
        {
            if (Flow.FlowContext == null)
            {
                throw new InvalidOperationException("No flow context");
            }

            this.stateMachineType = typeof(TStateMachine);
            this.fileName = this.FileName<TStateMachine>();
            this.result = new StateMachineFlowData(this.stateMachineType.FullName);

            var flow = EnsureState(ref stateMachine);

            this.RestoreState(flow);

            flow.stateMachine.MoveNext();
        }

        /// <summary>Intentionally empty; the builder tracks the state machine itself.</summary>
        public void SetStateMachine(IAsyncStateMachine stateMachine)
        {
        }

        /// <summary>Faults the flow and deletes its checkpoint file.</summary>
        public void SetException(Exception exception)
        {
            if (this.result == null)
            {
                throw new InvalidOperationException("Task should be active already");
            }

            this.result.SetException(exception);
            File.Delete(this.fileName);
        }

        /// <summary>Completes the flow and deletes its checkpoint file.</summary>
        public void SetResult()
        {
            if (this.result == null)
            {
                throw new InvalidOperationException("Task should be active already");
            }

            this.result.SetResult();
            File.Delete(this.fileName);
        }

        public void AwaitOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : INotifyCompletion
            where TStateMachine : IAsyncStateMachine => this.HandleCompleted(ref awaiter, ref stateMachine, false);

        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : ICriticalNotifyCompletion
            where TStateMachine : IAsyncStateMachine => this.HandleCompleted(ref awaiter, ref stateMachine, true);

        /// <summary>
        /// Shared implementation of the two Await*OnCompleted methods. When the
        /// awaiter is a FlowData the current state is written to the checkpoint
        /// file before the continuation is registered. For any other awaiter the
        /// continuation is registered without writing a checkpoint.
        /// </summary>
        /// <param name="useCritical">True to register via UnsafeOnCompleted for non-FlowData awaiters.</param>
        /// <exception cref="InvalidOperationException">No flow context is set.</exception>
        internal void HandleCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine, bool useCritical)
            where TAwaiter : INotifyCompletion
            where TStateMachine : IAsyncStateMachine
        {
            if (Flow.FlowContext == null)
            {
                throw new InvalidOperationException("No flow context");
            }

            var flow = EnsureState(ref stateMachine);

            //flow.Debug();

            if (awaiter is FlowData)
            {
                var writeState = flow.Capture();
                using (var writer = File.CreateText(this.fileName))
                {
                    Serializer.Serialize(writer, writeState);
                }

                // Checked again after the checkpoint has been written.
                if (Flow.FlowContext == null)
                {
                    throw new InvalidOperationException("No flow context set");
                }

                awaiter.OnCompleted(() => flow.stateMachine.MoveNext());
            }
            else
            {
                // TODO: Disable flow context here so async methods can't call back into flow
                if (useCritical)
                {
                    ((ICriticalNotifyCompletion)awaiter).UnsafeOnCompleted(() =>
                    {
                        flow.stateMachine.MoveNext();
                    });
                }
                else
                {
                    awaiter.OnCompleted(() =>
                    {
                        flow.stateMachine.MoveNext();
                    });
                }
            }
        }

        /// <summary>
        /// Loads the checkpoint file, if there is one, into <paramref name="flow"/>.
        /// If the checkpoint contains a StateMachineFlowData (a child flow that
        /// was being awaited), a new builder and child state machine are created
        /// and started, and the child's FlowData replaces the deserialized value.
        /// </summary>
        /// <exception cref="InvalidOperationException">The child state machine type named in the checkpoint cannot be resolved.</exception>
        private void RestoreState<TStateMachine>(FlowState<TStateMachine> flow)
        {
            if (!File.Exists(this.fileName))
            {
                return;
            }

            Dictionary<string, object> state;
            using (var stream = File.OpenRead(this.fileName))
            using (var text = new StreamReader(stream))
            using (var json = new JsonTextReader(text))
            {
                state = Serializer.Deserialize<Dictionary<string, object>>(json);
            }

            if (state == null)
            {
                // An empty file deserializes to null: nothing to restore.
                return;
            }

            var subState = state.SingleOrDefault(o => o.Value is StateMachineFlowData);
            if (subState.Value != null)
            {
                var childName = ((StateMachineFlowData)subState.Value).stateMachineName;
                var childType = childName == null ? null : Type.GetType(childName);
                if (childType == null)
                {
                    throw new InvalidOperationException(
                        $"Cannot resolve state machine type '{childName}' named in checkpoint '{this.fileName}'");
                }

                var subBuilder = Create();
                var start = typeof(FlowDataMethodBuilder).GetMethod(nameof(Start)).MakeGenericMethod(childType);
                var childStateMachine = Activator.CreateInstance(childType);
                start.Invoke(subBuilder, new[] { childStateMachine });
                state[subState.Key] = subBuilder.Task;
            }

            flow.RestoreState(this, state);
        }

        /// <summary>
        /// Builds the checkpoint file name: the state machine type's Name with
        /// characters that are invalid in file names replaced by '_', plus ".json".
        /// </summary>
        private string FileName<TStateMachine>() => string.Join('_', typeof(TStateMachine).Name.Split(Path.GetInvalidFileNameChars())) + ".json";

        /// <summary>
        /// Returns the FlowState for this builder, creating it on first use, and
        /// stores <paramref name="stateMachine"/> in it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The builder already tracks a different state machine type.</exception>
        private FlowState<TStateMachine> EnsureState<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
        {
            FlowState<TStateMachine> flow;
            if (this.state == null)
            {
                flow = new FlowState<TStateMachine>();
                this.state = flow;
            }
            else if (this.state is FlowState<TStateMachine> existingFlow)
            {
                flow = existingFlow;
            }
            else
            {
                throw new InvalidOperationException("Changing statemachine is not supported");
            }

            flow.stateMachine = stateMachine;
            return flow;
        }
    }
}
gafter commented 5 years ago

@bencyoung I don't think this is something we'd devote resources to. However, please let us know if you want to work on making these changes to the Roslyn compilers, and we can discuss if we want to support that.