microsoft / CsWin32

A source generator to add a user-defined set of Win32 P/Invoke methods and supporting types to a C# project.
MIT License
2k stars 84 forks source link

DnsServiceRegister is not generated correctly #1025

Closed hach-que closed 11 months ago

hach-que commented 11 months ago

Actual behavior

CsWin32 generates the function signature for DnsServiceRegister as:

[DllImport("DNSAPI.dll", ExactSpelling = true, SetLastError = true)]
[DefaultDllImportSearchPaths(DllImportSearchPath.System32)]
[SupportedOSPlatform("windows10.0.10240")]
public static extern unsafe uint DnsServiceRegister(in winmdroot.NetworkManagement.Dns.DNS_SERVICE_REGISTER_REQUEST pRequest, [Optional] winmdroot.NetworkManagement.Dns.DNS_SERVICE_CANCEL* pCancel);

This appears to be incorrect, as the DNS_SERVICE_REGISTER_REQUEST pointer needs to be fixed for the lifetime of the asynchronous call (i.e. the parameter should be DNS_SERVICE_REGISTER_REQUEST*).

Also, DNS_SERVICE_REGISTER_REQUEST is generated as a managed type, so at the moment you can't store it on the heap with Marshal.AllocHGlobal either.

DnsServiceDeRegister is also affected - this should also take a pointer, not the struct directly.

Expected behavior

The first parameter to DnsServiceRegister should be a pointer, and you should be able to put the DNS_SERVICE_REGISTER_REQUEST onto the heap via Marshal.AllocHGlobal.

Repro steps

  1. NativeMethods.txt content:

    DnsServiceRegister
    DnsServiceDeRegister
    DnsServiceBrowse
    DNS_QUERY_OPTIONS
    DNS_REQUEST_PENDING
    WIN32_ERROR
  2. NativeMethods.json content (if present):

    {
    "$schema": "https://aka.ms/CsWin32.schema.json",
    "public": true
    }
  3. Any of your own code that should be shared?

This will result in an access violation when run, due to DnsServiceRegister trying to access data on the stack that no longer exists:

namespace Redpoint.AutoDiscovery
{
    extern alias SDWin64;
    using Redpoint.Concurrency;
    using SDWin64::Windows.Win32;
    using SDWin64::Windows.Win32.Foundation;
    using SDWin64::Windows.Win32.NetworkManagement.Dns;
    using System.Collections.Concurrent;
    using System.ComponentModel;
    using System.Runtime.InteropServices;
    using System.Runtime.Versioning;

    [SupportedOSPlatform("windows10.0.10240")]
    internal class Win64NetworkAutoDiscovery : INetworkAutoDiscovery
    {
        // Register requests that have been started, keyed by request id. The completion
        // callback looks its request up here using the id it receives as query context.
        private static readonly ConcurrentDictionary<nint, InflightRegisterRequest> _inflightRegister = new ConcurrentDictionary<nint, InflightRegisterRequest>();
        // De-register requests that have been started, keyed by request id.
        private static readonly ConcurrentDictionary<nint, InflightDeRegisterRequest> _inflightDeRegister = new ConcurrentDictionary<nint, InflightDeRegisterRequest>();
        // Next request id to hand out; shared by register and de-register requests.
        private static nint _nextId = 1000;

        /// <summary>
        /// Bookkeeping for a register request that has been handed to the native API
        /// and is waiting for its completion callback.
        /// </summary>
        private class InflightRegisterRequest
        {
            /// <summary>Signalled by the completion callback once the request has finished.</summary>
            public readonly Gate AsyncSemaphore = new();

            /// <summary>Key of this request in the in-flight dictionary.</summary>
            public required nint Id;

            /// <summary>The request structure that was submitted.</summary>
            public DNS_SERVICE_REGISTER_REQUEST RegisterRequest;

            /// <summary>Address of the unmanaged DNS_SERVICE_INSTANCE allocated for this request.</summary>
            public nint ServiceInstance;

            /// <summary>Failure reported by the completion callback, or null when none was reported.</summary>
            public Exception? ResultException;
        }

        /// <summary>
        /// Bookkeeping for a de-register request that has been handed to the native API
        /// and is waiting for its completion callback.
        /// </summary>
        private class InflightDeRegisterRequest
        {
            /// <summary>Signalled by the completion callback once the request has finished.</summary>
            public readonly Gate AsyncSemaphore = new();

            /// <summary>Key of this request in the in-flight dictionary.</summary>
            public required nint Id;

            /// <summary>The request structure that was submitted.</summary>
            public DNS_SERVICE_REGISTER_REQUEST RegisterRequest;

            /// <summary>Address of the unmanaged DNS_SERVICE_INSTANCE being de-registered.</summary>
            public nint ServiceInstance;

            /// <summary>Failure reported by the completion callback, or null when none was reported.</summary>
            public Exception? ResultException;
        }

        /// <summary>
        /// Allocates the unmanaged service instance, records an in-flight request and
        /// submits it to <c>DnsServiceRegister</c>.
        /// </summary>
        /// <param name="name">Service instance name; copied into unmanaged memory.</param>
        /// <param name="port">Port to advertise; truncated to 16 bits.</param>
        /// <returns>The in-flight request when the native call reports <c>DNS_REQUEST_PENDING</c>.</returns>
        /// <exception cref="Win32Exception">The native call returned anything other than <c>DNS_REQUEST_PENDING</c>.</exception>
        private static unsafe InflightRegisterRequest StartRegisterService(string name, int port)
        {
            var instanceName = Marshal.StringToHGlobalUni(name);

            var requestPending = false;
            var serviceInstance = (DNS_SERVICE_INSTANCE*)Marshal.AllocHGlobal(sizeof(DNS_SERVICE_INSTANCE));
            // AllocHGlobal does not zero-fill the block. Clear it so that every field we do
            // not set below is null/zero instead of a garbage pointer or count that the
            // native side would dereference.
            // NOTE(review): the issue's resolution was to build the instance with
            // DnsServiceConstructInstance; zeroing removes the uninitialized fields but the
            // native constructor remains the supported way to create this structure.
            *serviceInstance = default;
            serviceInstance->pszInstanceName = (char*)instanceName;
            serviceInstance->wPort = (ushort)port;
            try
            {
                var requestId = _nextId++;
                var registerRequestInstance = new DNS_SERVICE_REGISTER_REQUEST
                {
                    Version = (uint)DNS_QUERY_OPTIONS.DNS_QUERY_REQUEST_VERSION1,
                    InterfaceIndex = 0,
                    pServiceInstance = serviceInstance,
                    pRegisterCompletionCallback = EndRegisterService,
                    // The id travels through native code and comes back to the callback,
                    // which uses it to find the in-flight record.
                    pQueryContext = (void*)requestId,
                    unicastEnabled = false
                };
                var request = new InflightRegisterRequest
                {
                    Id = requestId,
                    ServiceInstance = (nint)serviceInstance,
                    RegisterRequest = registerRequestInstance,
                };
                // Publish the record before calling native code so the callback can find it.
                _inflightRegister[requestId] = request;
                try
                {
                    var result = PInvoke.DnsServiceRegister(registerRequestInstance, null);
                    if (result != PInvoke.DNS_REQUEST_PENDING)
                    {
                        throw new Win32Exception((int)result);
                    }
                    else
                    {
                        requestPending = true;
                        return request;
                    }
                }
                finally
                {
                    if (!requestPending)
                    {
                        _inflightRegister.TryRemove(requestId, out _);
                    }
                }
            }
            finally
            {
                // Once the request is pending the unmanaged memory stays allocated for the
                // callbacks to release; otherwise it is released here.
                if (!requestPending)
                {
                    Marshal.FreeHGlobal((nint)serviceInstance->pszInstanceName.Value);
                    Marshal.FreeHGlobal((nint)serviceInstance);
                }
            }
        }

        /// <summary>
        /// Completion callback for a register request. Looks up the in-flight record by
        /// the id carried in <paramref name="queryContext"/>, records a failure (and frees
        /// the unmanaged service instance) when <paramref name="status"/> is not success,
        /// then releases the waiter.
        /// </summary>
        private static unsafe void EndRegisterService(
            uint status,
            void* queryContext,
            DNS_SERVICE_INSTANCE* instance)
        {
            var pending = _inflightRegister[(nint)queryContext];

            if (status == (uint)WIN32_ERROR.ERROR_SUCCESS)
            {
                // Success: the service instance stays allocated for the later de-register.
                pending.AsyncSemaphore.Unlock();
                return;
            }

            // Failure: free the instance name first, then the structure that points to it.
            var failedInstance = (DNS_SERVICE_INSTANCE*)pending.ServiceInstance;
            Marshal.FreeHGlobal((nint)failedInstance->pszInstanceName.Value);
            Marshal.FreeHGlobal(pending.ServiceInstance);
            pending.ResultException = new Win32Exception((int)status);
            pending.AsyncSemaphore.Unlock();
        }

        /// <summary>
        /// Records an in-flight de-register request and submits it to
        /// <c>DnsServiceDeRegister</c>.
        /// </summary>
        /// <param name="registerRequestInstance">The request that was used to register the service (passed by value; the caller's copy is not modified).</param>
        /// <param name="serviceInstanceRaw">Address of the unmanaged DNS_SERVICE_INSTANCE; must match the request's <c>pServiceInstance</c>.</param>
        /// <returns>The in-flight request when the native call reports <c>DNS_REQUEST_PENDING</c>.</returns>
        /// <exception cref="InvalidOperationException">The request does not point at <paramref name="serviceInstanceRaw"/>.</exception>
        /// <exception cref="Win32Exception">The native call returned anything other than <c>DNS_REQUEST_PENDING</c>.</exception>
        private static unsafe InflightDeRegisterRequest StartDeRegisterService(
            DNS_SERVICE_REGISTER_REQUEST registerRequestInstance,
            nint serviceInstanceRaw)
        {
            var serviceInstance = (DNS_SERVICE_INSTANCE*)serviceInstanceRaw;
            if (registerRequestInstance.pServiceInstance != serviceInstance)
            {
                throw new InvalidOperationException();
            }

            var requestPending = false;
            var requestId = _nextId++;

            // Re-target the request at the de-register callback and the new id BEFORE it is
            // copied into the in-flight record or handed to native code. The incoming
            // structure still carries the id of the original register request, which is not
            // a key in _inflightDeRegister, so the callback's lookup would otherwise fail.
            registerRequestInstance.pRegisterCompletionCallback = EndDeRegisterService;
            registerRequestInstance.pQueryContext = (void*)requestId;

            var request = new InflightDeRegisterRequest
            {
                Id = requestId,
                ServiceInstance = (nint)serviceInstance,
                RegisterRequest = registerRequestInstance,
            };
            // Publish the record before calling native code so the callback can find it.
            _inflightDeRegister[requestId] = request;
            try
            {
                var result = PInvoke.DnsServiceDeRegister(registerRequestInstance, null);
                if (result != PInvoke.DNS_REQUEST_PENDING)
                {
                    throw new Win32Exception((int)result);
                }
                else
                {
                    requestPending = true;
                    return request;
                }
            }
            finally
            {
                // When the request never became pending no callback will run, so the
                // record and the unmanaged memory are released here instead.
                if (!requestPending)
                {
                    _inflightDeRegister.TryRemove(requestId, out _);
                    Marshal.FreeHGlobal((nint)serviceInstance->pszInstanceName.Value);
                    Marshal.FreeHGlobal(serviceInstanceRaw);
                }
            }
        }

        /// <summary>
        /// Completion callback for a de-register request. Looks up the in-flight record by
        /// the id carried in <paramref name="queryContext"/>, records a failure when
        /// <paramref name="status"/> is not success, frees the unmanaged service instance
        /// in either case, then releases the waiter.
        /// </summary>
        private static unsafe void EndDeRegisterService(
            uint status,
            void* queryContext,
            DNS_SERVICE_INSTANCE* instance)
        {
            var pending = _inflightDeRegister[(nint)queryContext];

            var succeeded = status == (uint)WIN32_ERROR.ERROR_SUCCESS;
            if (!succeeded)
            {
                pending.ResultException = new Win32Exception((int)status);
            }

            // Free the instance name first, then the structure that points to it.
            var registeredInstance = (DNS_SERVICE_INSTANCE*)pending.ServiceInstance;
            Marshal.FreeHGlobal((nint)registeredInstance->pszInstanceName.Value);
            Marshal.FreeHGlobal(pending.ServiceInstance);

            pending.AsyncSemaphore.Unlock();
        }

        /// <summary>
        /// Handle returned to callers after a successful registration; disposing it
        /// de-registers the service.
        /// </summary>
        private class DnsDeregisterAsyncDisposable : IAsyncDisposable
        {
            private readonly DNS_SERVICE_REGISTER_REQUEST _request;
            private readonly nint _service;

            /// <param name="request">The request that was used to register the service.</param>
            /// <param name="service">Address of the unmanaged DNS_SERVICE_INSTANCE that was registered.</param>
            public DnsDeregisterAsyncDisposable(
                DNS_SERVICE_REGISTER_REQUEST request,
                nint service)
            {
                _request = request;
                _service = service;
            }

            /// <summary>
            /// Starts the de-register request, waits for its completion callback and
            /// rethrows any failure the callback recorded.
            /// </summary>
            public async ValueTask DisposeAsync()
            {
                var request = StartDeRegisterService(_request, _service);
                await request.AsyncSemaphore.WaitAsync(CancellationToken.None);
                // The request was stored in the de-register dictionary, so that is the one
                // it must be removed from; removing from _inflightRegister would leave the
                // entry behind forever.
                _inflightDeRegister.TryRemove(request.Id, out _);
                if (request.ResultException != null)
                {
                    throw request.ResultException;
                }
            }
        }

        /// <summary>
        /// Registers a service and waits for the native completion callback.
        /// </summary>
        /// <param name="name">Service instance name.</param>
        /// <param name="port">Port to advertise.</param>
        /// <param name="cancellationToken">Cancels the wait for the completion callback.</param>
        /// <returns>A handle whose disposal de-registers the service.</returns>
        public async Task<IAsyncDisposable> RegisterServiceAsync(string name, int port, CancellationToken cancellationToken)
        {
            var pending = StartRegisterService(name, port);

            await pending.AsyncSemaphore.WaitAsync(cancellationToken);

            // The callback has run, so the record is no longer needed for lookups.
            _inflightRegister.TryRemove(pending.Id, out _);

            if (pending.ResultException is { } failure)
            {
                throw failure;
            }

            return new DnsDeregisterAsyncDisposable(pending.RegisterRequest, pending.ServiceInstance);
        }

        /// <summary>Service discovery; not implemented in this repro.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        public IAsyncEnumerable<NetworkService> DiscoverServicesAsync(string name)
            => throw new NotImplementedException();
    }
}

Context

hach-que commented 11 months ago

Ok, turns out this was caused by not using DnsServiceConstructInstance to allocate the DNS_SERVICE_INSTANCE.