Azure / autorest.csharp

Extension for AutoRest (https://github.com/Azure/autorest) that generates C# code
MIT License
142 stars 165 forks source link

Multipart support in DPG #2576

Closed AlexanderSher closed 5 months ago

AlexanderSher commented 2 years ago

The Agrifood service has a multipart request body. Right now, autorest.csharp just picks the last body parameter; as a result, we generate incorrect samples and don't add the required headers to the request.

lirenhe commented 1 year ago

@AlexanderSher, will you work on this item?

chunyu3 commented 1 year ago

Move to GA

lirenhe commented 1 year ago

not a priority for now, move it into the backlog

chunyu3 commented 11 months ago

tsp

// Request body for the audio transcription operation. Only `file` is required (it has no '?');
// every other property is optional. @projectedName renames a property for one projection:
// "csharp" changes the generated C# member name, "json" changes the serialized field name.
model AudioTranscriptionOptions {
  // Required audio payload; the C# projection exposes it as AudioData.
  @doc("""
  The audio data to transcribe. This must be the binary content of a file in one of the supported media formats:
   flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm.
  """)
  @projectedName("csharp", "AudioData")
  file: bytes;

  // Serialized under the snake_case name "response_format".
  @doc("""
  The requested format of the transcription response data, which will influence the content and detail of the result.
  """)
  @projectedName("json", "response_format")
  responseFormat?: AudioTranscriptionFormat;

  @doc("""
  The primary spoken language of the audio data to be transcribed, supplied as a two-letter ISO-639-1 language code
  such as 'en' or 'fr'.
  Providing this known input language is optional but may improve the accuracy and/or latency of transcription.
  """)
  language?: string;

  @doc("""
  An optional hint to guide the model's style or continue from a prior audio segment. The written language of the
  prompt should match the primary spoken language of the audio data.
  """)
  prompt?: string;

  @doc("""
  The sampling temperature, between 0 and 1.
  Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
  If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
  """)
  temperature?: float32;

  @doc("""
  The model to use for this transcription request.
  """)
  // Implementation note: developer-facing specification of deployment or model by clients should be controlled either
  // via an operation parameter or by this request body field -- but only one of those. This field should be hidden by
  // clients if operation parameters are used and populated into the request body on an as-needed basis.
  // `model` is backtick-escaped because `model` is a TypeSpec keyword; C# sees it as InternalNonAzureModelName.
  @projectedName("csharp", "InternalNonAzureModelName")
  `model`?: string;
}

// Transcribes audio and returns the transcript as a plain string response.
// Template arguments: resource (Deployment), request body (AudioTranscriptionOptions),
// response (string), and request traits. NOTE(review): MultipartFormDataRequestHeadersTraits
// presumably makes the request declare a multipart/form-data Content-Type; confirm against its definition.
@action("audio/transcriptions")
op getAudioTranscriptionAsPlainText is Azure.Core.Foundations.ResourceOperation<
  Deployment,
  AudioTranscriptionOptions,  // response_format must be one of: text, srt, vtt
  string,
  MultipartFormDataRequestHeadersTraits
>;

Option 1:

//AudioTranscriptionOptions.serialization.cs
public partial class AudioTranslationOptions
{
    // NOTE(review): this sketch lives in AudioTranscriptionOptions.serialization.cs and its members are the
    // transcription options, yet the class is named AudioTranslationOptions; likely a copy-paste slip, confirm.

    /// <summary> Writes this model as a JSON object, omitting optional properties that are not defined. </summary>
    /// <param name="writer"> The JSON writer to serialize into. </param>
    void IUtf8JsonSerializable.Write(Utf8JsonWriter writer)
    {
        writer.WriteStartObject();
        // The audio payload is required and is written base64-encoded.
        writer.WritePropertyName("file"u8);
        writer.WriteBase64StringValue(AudioData.ToArray(), "D");
        if (Optional.IsDefined(ResponseFormat))
        {
            writer.WritePropertyName("response_format"u8);
            writer.WriteStringValue(ResponseFormat.Value.ToString());
        }
        if (Optional.IsDefined(Language))
        {
            writer.WritePropertyName("language"u8);
            writer.WriteStringValue(Language);
        }
        if (Optional.IsDefined(Prompt))
        {
            writer.WritePropertyName("prompt"u8);
            writer.WriteStringValue(Prompt);
        }
        if (Optional.IsDefined(Temperature))
        {
            writer.WritePropertyName("temperature"u8);
            writer.WriteNumberValue(Temperature.Value);
        }
        if (Optional.IsDefined(InternalNonAzureModelName))
        {
            writer.WritePropertyName("model"u8);
            writer.WriteStringValue(InternalNonAzureModelName);
        }
        writer.WriteEndObject();
    }

    /// <summary>
    /// Converts this model to request content. "multipart/form-data" produces one form part per defined
    /// property; any other content type produces JSON via <see cref="IUtf8JsonSerializable"/>.
    /// </summary>
    /// <param name="contentType"> The desired body encoding; defaults to "application/json". </param>
    /// <returns> The request content; the caller owns and disposes it. </returns>
    internal virtual RequestContent ToRequestContent(ContentType contentType = "application/json")
    {
        // TODO: only application/json and multipart/form-data are supported today; add more content types
        // when a model needs them.
        if (contentType == "multipart/form-data")
        {
            var content = new MultipartFormDataContent(Guid.NewGuid().ToString());
            // `model` is optional, so only send it when set; a null value cannot become request content.
            if (Optional.IsDefined(InternalNonAzureModelName))
            {
                content.Add(RequestContent.Create(InternalNonAzureModelName), "model", null);
            }
            // NOTE(review): the file name (and therefore its extension) is fixed regardless of the actual
            // audio format; confirm the service does not rely on it for format detection.
            content.Add(RequestContent.Create(AudioData.ToArray()), "file", "file.wav", null);
            if (Optional.IsDefined(ResponseFormat))
            {
                content.Add(RequestContent.Create(ResponseFormat.Value.ToString()), "response_format", null);
            }
            if (Optional.IsDefined(Language))
            {
                content.Add(RequestContent.Create(Language), "language", null);
            }
            if (Optional.IsDefined(Prompt))
            {
                content.Add(RequestContent.Create(Prompt), "prompt", null);
            }
            if (Optional.IsDefined(Temperature))
            {
                // Invariant culture keeps the decimal separator a '.', independent of the caller's locale.
                content.Add(RequestContent.Create(Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature", null);
            }
            return content;
        }

        var jsonContent = new Utf8JsonRequestContent();
        jsonContent.JsonWriter.WriteObjectValue(this);
        return jsonContent;
    }
}

//protocol

// Option 1 protocol method: callers pass pre-built RequestContent instead of the options model.
// NOTE(review): the body is elided in this proposal; as written it does not compile, because an
// async Task<Response> method must return a Response on every path.
public virtual async Task<Response> GetAudioTranscriptionAsync(string deploymentId, RequestContent content, RequestContext context = null)
{
}

//convenience method

/// <summary> Transcribes audio by sending <paramref name="audioTranscriptionOptions"/> as multipart/form-data. </summary>
/// <param name="deploymentId"> The deployment to target; it is also written into the options' model name. </param>
/// <param name="audioTranscriptionOptions"> The transcription request. This method overwrites its InternalNonAzureModelName. </param>
/// <param name="cancellationToken"> Token used to cancel the request. </param>
/// <returns> The deserialized transcription together with the raw response. </returns>
public virtual Response<AudioTranscription> GetAudioTranscription(string deploymentId, AudioTranscriptionOptions audioTranscriptionOptions, CancellationToken cancellationToken = default)
{
    Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
    Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions));

    using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
    scope.Start();
    try
    {
        // NOTE(review): this mutates the caller's options object; confirm that is intended.
        audioTranscriptionOptions.InternalNonAzureModelName = deploymentId;

        RequestContext context = FromCancellationToken(cancellationToken);
        using RequestContent content = audioTranscriptionOptions.ToRequestContent("multipart/form-data");
        Response response = GetAudioTranscription(deploymentId, content, context);
        return Response.FromValue(AudioTranscription.FromResponse(response), response);
    }
    catch (Exception e)
    {
        // Mark the diagnostic scope as failed so tracing records the error, then rethrow unchanged.
        scope.Failed(e);
        throw;
    }
}

Set the correct Content-Type in the CreateXXXRequest method:

/// <summary> Builds the POST message for "audio/transcriptions" with a multipart/form-data body. </summary>
/// <param name="deploymentId"> The deployment segment of the request URI. </param>
/// <param name="content"> The request body; must be a <see cref="MultipartFormDataContent"/>. </param>
/// <param name="context"> Per-call request options. </param>
/// <exception cref="ArgumentException"> <paramref name="content"/> is null or not multipart form-data. </exception>
internal HttpMessage CreateGetAudioTranscriptionRequest(string deploymentId, RequestContent content, RequestContext context)
{
    // ApplyToRequest is only available on MultipartFormDataContent; an unchecked `as` cast would turn any
    // other content into a NullReferenceException, so reject it up front with a clear message.
    if (content is not MultipartFormDataContent multipartContent)
    {
        throw new ArgumentException("The content for this operation must be multipart/form-data.", nameof(content));
    }

    HttpMessage message = _pipeline.CreateMessage(context, ResponseClassifier200);
    Request request = message.Request;
    request.Method = RequestMethod.Post;
    request.Uri = GetUri(deploymentId, "audio/transcriptions");
    multipartContent.ApplyToRequest(request);
    return message;
}

Pros and cons:

Pros: it leverages the current model serialization and can convert each part to its expected content type. Cons: ToRequestContent is internal, so customers cannot convert their model to the correct content and call the protocol method themselves. So we need to leverage the public serialization / ModelReaderWriter feature (option 3).

chunyu3 commented 10 months ago

Option 2: //protocol

// Option 2 protocol method: callers pass pre-built RequestContent instead of the options model.
// NOTE(review): the body is elided, and this signature returns Task rather than the Task<Response>
// used in option 1; presumably an omission, confirm.
public virtual async Task GetAudioTranscriptionAsync(string deploymentId, RequestContent content, RequestContext context = null)
{
}

//convenience method

/// <summary> Transcribes audio by converting <paramref name="audioTranscriptionOptions"/> to multipart content with a generic converter. </summary>
/// <param name="deploymentId"> The deployment to target; it is also written into the options' model name. </param>
/// <param name="audioTranscriptionOptions"> The transcription request. This method overwrites its InternalNonAzureModelName. </param>
/// <param name="cancellationToken"> Token used to cancel the request. </param>
/// <returns> The deserialized transcription together with the raw response. </returns>
public virtual Response<AudioTranscription> GetAudioTranscription(string deploymentId, AudioTranscriptionOptions audioTranscriptionOptions, CancellationToken cancellationToken = default)
{
    Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
    Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions));

    using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
    scope.Start();
    try
    {
        // NOTE(review): this mutates the caller's options object; confirm that is intended.
        audioTranscriptionOptions.InternalNonAzureModelName = deploymentId;

        RequestContext context = FromCancellationToken(cancellationToken);
        // NOTE(review): "Mulitipart" in the proposed converter's name looks like a typo for "Multipart".
        using RequestContent content = RequestContent.Create(new MulitipartFormDataConverter(audioTranscriptionOptions));
        Response response = GetAudioTranscription(deploymentId, content, context);
        return Response.FromValue(AudioTranscription.FromResponse(response), response);
    }
    catch (Exception e)
    {
        // Mark the diagnostic scope as failed so tracing records the error, then rethrow unchanged.
        scope.Failed(e);
        throw;
    }
}
Set the correct Content-Type in the CreateXXXRequest method:

/// <summary> Builds the POST message for "audio/transcriptions", forwarding the content's own Content-Type. </summary>
/// <param name="deploymentId"> The deployment segment of the request URI. </param>
/// <param name="content"> The request body; its ContentType (a proposed property) becomes the Content-Type header. </param>
/// <param name="context"> Per-call request options. </param>
internal HttpMessage CreateGetAudioTranscriptionRequest(string deploymentId, RequestContent content, RequestContext context)
{
    HttpMessage message = _pipeline.CreateMessage(context, ResponseClassifier200);
    Request request = message.Request;
    request.Method = RequestMethod.Post;
    request.Uri = GetUri(deploymentId, "audio/transcriptions");
    // Only emit the header when there is content and it knows its type; this avoids dereferencing a
    // null content and adding a header with a null value.
    if (content?.ContentType is string contentType)
    {
        request.Headers.Add("Content-Type", contentType);
    }
    request.Content = content;
    return message;
}

Cons:

  1. Performance: converting the model requires reflection, which may cause performance issues; this needs performance testing.
  2. When properties of the model have different encodings/content types (i.e., each part has its own content type), or a part needs custom headers, this approach cannot handle it because the model structure loses that information.
chunyu3 commented 10 months ago

Option 3: leverage the ModelReaderWriter feature. Similar to option 1, convert the model to multipart/form-data in the serialization code.

Add a function to RequestContent that converts a model to RequestContent:

/// <summary>
/// Creates request content that serializes <paramref name="model"/> with <paramref name="options"/>
/// when the body is first needed.
/// </summary>
/// <typeparam name="T"> The model type. </typeparam>
/// <param name="model"> The model to send. </param>
/// <param name="options"> Serialization options, e.g. the multipart/form-data format. </param>
/// <exception cref="ArgumentNullException"> <paramref name="model"/> is null. </exception>
public static RequestContent Create<T>(IModel<T> model, ModelReaderWriterOptions options)
{
  // Serialization is deferred, so fail fast here instead of when the request is sent.
  if (model is null)
  {
    throw new ArgumentNullException(nameof(model));
  }
  return new ModelReadWriteRequestContent<T>(model, options);
}

Add a private nested class ModelReadWriteRequestContent<T>:

//RequestContent.cs
public partial class RequestContent
{
  /// <summary>
  /// The media type of this content (for example "multipart/form-data; boundary=..."), or null when unknown.
  /// Request builders copy it into the Content-Type header.
  /// </summary>
  public virtual string? ContentType => null;

  /// <summary>
  /// Request content backed by a model. The model is serialized through ModelReaderWriter the first time
  /// its bytes, length, or content type are needed, and the result is cached.
  /// </summary>
  private sealed class ModelReadWriteRequestContent<T> : RequestContent
  {
     private readonly IModel<T> _model;
     private readonly ModelReaderWriterOptions _options;
     private BinaryData? _data;

     public ModelReadWriteRequestContent(IModel<T> model, ModelReaderWriterOptions options)
     {
         _model = model;
         _options = options;
     }

     /// <summary> The serialized model; computed once and cached. </summary>
     public BinaryData Data => _data ??= ModelReaderWriter.Write(_model, _options); // calls the Write function in the serialization file

     // Reading ContentType forces serialization, so a request builder that sets the Content-Type header
     // before the body is written still gets the real value (including any multipart boundary).
     public override string? ContentType => Data.ContentType;

     public override void WriteTo(Stream stream, System.Threading.CancellationToken cancellation)
     {
        byte[] bytes = Data.ToArray();
        stream.Write(bytes, 0, bytes.Length);
     }

     public override async System.Threading.Tasks.Task WriteToAsync(Stream stream, System.Threading.CancellationToken cancellation)
     {
        byte[] bytes = Data.ToArray();
        await stream.WriteAsync(bytes, 0, bytes.Length, cancellation).ConfigureAwait(false);
     }

     public override bool TryComputeLength(out long length)
     {
        length = Data.ToMemory().Length;
        return true;
     }

     public override void Dispose()
     {
        // Nothing to release: the serialized bytes are managed memory.
     }
  }
}

Multipart support in the core library (System.ClientModel):

//MultipartContent.cs
    // Sketch of a generic multipart body: a list of parts, each carrying its own headers, plus a
    // top-level Content-Type of "multipart/<subtype>; boundary=<boundary>".
    // NOTE(review): WriteTo and TryComputeLength are declared `override`, but this class has no base type;
    // presumably it is meant to derive from RequestContent. Confirm.
    // NOTE(review): AddInternal, ValidateBoundary and GetDefaultBoundary are called but not defined in this sketch.
    public class MultipartContent
    {

        // Parts in the order they were added.
        private readonly List<MultipartContentPart> _nestedContent;
        // The part after "multipart/", e.g. "mixed" or "form-data".
        private readonly string _subtype;
        // Boundary as it appears in the Content-Type header (possibly quoted, see the constructor).
        private readonly string _boundary;
        // Headers for the whole body; the constructor seeds Content-Type.
        internal readonly Dictionary<string, string> _headers;

        // Defaults to "multipart/mixed" with a generated boundary.
        public MultipartContent()
            : this("mixed", GetDefaultBoundary())
        { }

        public MultipartContent(string subtype)
            : this(subtype, GetDefaultBoundary())
        { }

        public MultipartContent(string subtype, string boundary)
        {
            ValidateBoundary(boundary);
            _subtype = subtype;

            // see https://www.ietf.org/rfc/rfc1521.txt page 29.
            // NOTE(review): only ':' triggers quoting, but other legal boundary characters (space, '(', ')',
            // ',', '/', '=', '?') are also tspecials that require quoting. Also, if WriteTo builds the
            // "--boundary" delimiters from _boundary, the quotes would leak into them. Confirm both.
            _boundary = boundary.Contains(":") ? $"\"{boundary}\"" : boundary;
            _headers = new Dictionary<string, string>
            {
                [HttpHeader.Names.ContentType] = $"multipart/{_subtype}; boundary={_boundary}"
            };

            _nestedContent = new List<MultipartContentPart>();
        }
        // Adds a part with no explicit headers.
        public virtual void Add(BinaryData content)
        {
            Argument.AssertNotNull(content, nameof(content));
            AddInternal(content, null);
        }
        // Adds a part with caller-supplied headers.
        public virtual void Add(BinaryData content, Dictionary<string, string> headers)
        {
            Argument.AssertNotNull(content, nameof(content));
            Argument.AssertNotNull(headers, nameof(headers));

            AddInternal(content, headers);
        }

        // Writes the framed parts to the stream (implementation elided in this sketch).
        public override void WriteTo(Stream stream, CancellationToken cancellationToken)
        {
             ..... 
         }
        // NOTE(review): empty body; it must assign `length` and return a value before this compiles.
        public override bool TryComputeLength(out long length)
        {

        }

        // One body part: its payload and its per-part headers.
        private class MultipartContentPart
        {
            public readonly BinaryData Content;
            public Dictionary<string, string> Headers;

            public MultipartContentPart(BinaryData content, Dictionary<string, string> headers)
            {
                Content = content;
                Headers = headers;
            }
        }
    }

//MultipartFormDataContent.cs
    /// <summary>
    /// A multipart/form-data body (RFC 7578). Unless the caller supplies them, each part receives a
    /// Content-Disposition header built from its field name and optional file name, and a Content-Type
    /// taken from the content itself.
    /// </summary>
    public class MultipartFormDataContent : MultipartContent
    {
        #region Fields

        private const string FormData = "form-data";

        #endregion Fields

        #region Construction

        /// <summary> Creates form-data content with a generated boundary. </summary>
        public MultipartFormDataContent() : base(FormData)
        { }

        /// <summary> Creates form-data content that uses <paramref name="boundary"/>. </summary>
        public MultipartFormDataContent(string boundary) : base(FormData, boundary)
        { }

        #endregion Construction

        /// <summary> Adds an unnamed part. </summary>
        public override void Add(BinaryData content)
        {
            Argument.AssertNotNull(content, nameof(content));
            AddInternal(content, null, null, null);
        }

        /// <summary> Adds an unnamed part with extra headers. </summary>
        public override void Add(BinaryData content, Dictionary<string, string> headers)
        {
            Argument.AssertNotNull(content, nameof(content));
            Argument.AssertNotNull(headers, nameof(headers));

            AddInternal(content, headers, null, null);
        }

        /// <summary> Adds a named form field; <paramref name="headers"/> may be null. </summary>
        public void Add(BinaryData content, string name, Dictionary<string, string> headers)
        {
            Argument.AssertNotNull(content, nameof(content));
            Argument.AssertNotNullOrWhiteSpace(name, nameof(name));

            AddInternal(content, headers, name, null);
        }

        /// <summary> Adds a named file field; <paramref name="headers"/> may be null. </summary>
        public void Add(BinaryData content, string name, string fileName, Dictionary<string, string> headers)
        {
            Argument.AssertNotNull(content, nameof(content));
            Argument.AssertNotNullOrWhiteSpace(name, nameof(name));
            Argument.AssertNotNullOrWhiteSpace(fileName, nameof(fileName));

            AddInternal(content, headers, name, fileName);
        }

        private void AddInternal(BinaryData content, Dictionary<string, string> headers, string name, string fileName)
        {
            // Copy into a case-insensitive dictionary: HTTP header names are case-insensitive, and the
            // caller's dictionary must not be mutated. Later duplicates win instead of throwing.
            var partHeaders = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
            if (headers != null)
            {
                foreach (KeyValuePair<string, string> header in headers)
                {
                    partHeaders[header.Key] = header.Value;
                }
            }

            if (!partHeaders.ContainsKey("Content-Disposition"))
            {
                // RFC 7578 section 4.2: "name" and "filename" are quoted-string parameters.
                var disposition = FormData;
                if (name != null)
                {
                    disposition += "; name=" + QuoteParameter(name);
                }
                if (fileName != null)
                {
                    disposition += "; filename=" + QuoteParameter(fileName);
                }
                partHeaders.Add("Content-Disposition", disposition);
            }

            if (!partHeaders.ContainsKey("Content-Type"))
            {
                // Prefer the content's own type. File parts without one default to application/octet-stream;
                // other parts omit the header and fall back to the RFC 7578 default of text/plain.
                string contentType = content.ContentType ?? (fileName != null ? "application/octet-stream" : null);
                if (contentType != null)
                {
                    partHeaders.Add("Content-Type", contentType);
                }
            }

            base.Add(content, partHeaders);
        }

        /// <summary> Wraps <paramref name="value"/> in double quotes, backslash-escaping '\' and '"'. </summary>
        private static string QuoteParameter(string value)
            => "\"" + value.Replace("\\", "\\\\").Replace("\"", "\\\"") + "\"";
    }

Convert the model to multipart content in the serialization file:

//AudioTranscriptionOptions.serialization.cs
public partial class AudioTranslationOptions {
        // ... other members of this partial class (e.g. JSON serialization) are elided in this sketch ...

        /// <summary>
        /// Serializes this model in the format requested by <paramref name="options"/>: JSON or multipart/form-data.
        /// "W" resolves to this model's wire format first.
        /// </summary>
        /// <exception cref="NotImplementedException"> Bicep was requested; it is not written yet. </exception>
        /// <exception cref="FormatException"> The format is not supported by this model. </exception>
        BinaryData IPersistableModel<AudioTranslationOptions>.Write(ModelReaderWriterOptions options)
        {
            // Resolve "W" through this model's own interface (not another model's).
            var format = options.Format == "W" ? ((IPersistableModel<AudioTranslationOptions>)this).GetWireFormat(options) : options.Format;

            if (format == ModelReaderWriterFormat.Json)
            {
                // NOTE(review): assumes ModelReaderWriter writes JSON through the model's JSON writer (presumably
                // in the elided members) rather than re-entering this method; confirm.
                return ModelReaderWriter.Write(this, options);
            }
            if (format == ModelReaderWriterFormat.MultipartFormData)
            {
                return WriteMultipartFormData();
            }
            if (format == ModelReaderWriterFormat.Bicep)
            {
                // TODO: write bicep
                throw new NotImplementedException("Bicep serialization is not implemented yet.");
            }
            throw new FormatException($"The model {nameof(AudioTranslationOptions)} does not support writing '{options.Format}' format.");
        }

        /// <summary> Serializes this model as multipart/form-data, with one part per defined property. </summary>
        private BinaryData WriteMultipartFormData()
        {
            string boundary = Guid.NewGuid().ToString();
            // This demo uses Azure.Core.MultipartFormDataContent; another multipart helper that appends parts
            // one by one may replace it.
            var content = new MultipartFormDataContent(boundary);
            // `model` is optional, so only send it when set; BinaryData.FromString rejects null.
            if (Optional.IsDefined(InternalNonAzureModelName))
            {
                content.Add(BinaryData.FromString(InternalNonAzureModelName), "model", null);
            }
            // NOTE(review): the file name (and its extension) is fixed regardless of the actual audio format.
            content.Add(BinaryData.FromBytes(AudioData.ToArray()), "file", "file.wav", null);
            if (Optional.IsDefined(ResponseFormat))
            {
                content.Add(BinaryData.FromString(ResponseFormat.Value.ToString()), "response_format", null);
            }
            if (Optional.IsDefined(Language))
            {
                content.Add(BinaryData.FromString(Language), "language", null);
            }
            if (Optional.IsDefined(Prompt))
            {
                content.Add(BinaryData.FromString(Prompt), "prompt", null);
            }
            if (Optional.IsDefined(Temperature))
            {
                // Invariant culture keeps the decimal separator a '.', independent of the caller's locale.
                content.Add(BinaryData.FromString(Temperature.Value.ToString(System.Globalization.CultureInfo.InvariantCulture)), "temperature", null);
            }

            using MemoryStream stream = new MemoryStream();
            content.WriteTo(stream, CancellationToken.None);

            // A MemoryStream cannot exceed int.MaxValue bytes, so wrapping its buffer is always safe and avoids a copy.
            BinaryData binaryData = new BinaryData(stream.GetBuffer().AsMemory(0, (int)stream.Length));
            // NOTE(review): relies on the proposed settable BinaryData.ContentType; with System.Memory.Data 8+
            // pass the media type to the BinaryData(ReadOnlyMemory<byte>, string) constructor instead.
            binaryData.ContentType = $"multipart/form-data; boundary={boundary}";
            return binaryData;
        }
}

protocol method

// Option 3 protocol method: callers pass RequestContent, e.g. built with RequestContent.Create(model, options).
// NOTE(review): the body is elided, and this signature returns Task rather than the Task<Response>
// used in option 1; presumably an omission, confirm.
public virtual async Task GetAudioTranscriptionAsync(string deploymentId, RequestContent content, RequestContext context = null)
{
}

convenience method

/// <summary> Transcribes audio by serializing <paramref name="audioTranscriptionOptions"/> to multipart/form-data via ModelReaderWriter. </summary>
/// <param name="deploymentId"> The deployment to target; it is also written into the options' model name. </param>
/// <param name="audioTranscriptionOptions"> The transcription request. This method overwrites its InternalNonAzureModelName. </param>
/// <param name="cancellationToken"> Token used to cancel the request. </param>
/// <returns> The deserialized transcription together with the raw response. </returns>
public virtual Response<AudioTranscription> GetAudioTranscription(string deploymentId, AudioTranscriptionOptions audioTranscriptionOptions, CancellationToken cancellationToken = default)
{
    Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
    Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions));

    using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
    scope.Start();
    try
    {
        // NOTE(review): this mutates the caller's options object; confirm that is intended.
        audioTranscriptionOptions.InternalNonAzureModelName = deploymentId;

        RequestContext context = FromCancellationToken(cancellationToken);
        using RequestContent content = RequestContent.Create(audioTranscriptionOptions, ModelReaderWriterOptions.MultipartFormData);
        Response response = GetAudioTranscription(deploymentId, content, context);
        return Response.FromValue(AudioTranscription.FromResponse(response), response);
    }
    catch (Exception e)
    {
        // Mark the diagnostic scope as failed so tracing records the error, then rethrow unchanged.
        scope.Failed(e);
        throw;
    }
}
//CreateXXXRequest
/// <summary> Builds the POST message for "audio/transcriptions", forwarding the content's own Content-Type. </summary>
/// <param name="deploymentId"> The deployment segment of the request URI. </param>
/// <param name="content"> The request body; its ContentType (a proposed property) becomes the Content-Type header. </param>
/// <param name="context"> Per-call request options. </param>
internal HttpMessage CreateGetAudioTranscriptionRequest(string deploymentId, RequestContent content, RequestContext context)
{
    HttpMessage message = _pipeline.CreateMessage(context, ResponseClassifier200);
    Request request = message.Request;
    request.Method = RequestMethod.Post;
    request.Uri = GetUri(deploymentId, "audio/transcriptions");
    // Only emit the header when there is content and it knows its type; this avoids dereferencing a
    // null content and adding a header with a null value.
    if (content?.ContentType is string contentType)
    {
        request.Headers.Add("Content-Type", contentType);
    }
    request.Content = content;
    return message;
}

How to use the APIs:

var client = new MyClient(...);
var model = new AudioTranscriptionOptions (1, "x");

// Protocol method: relies on an implicit cast from the model to RequestContent, which assumes the "W" (wire) format.
client.GetAudioTranscription(id, model, new RequestContext());

// Protocol method with explicitly serialized multipart/form-data content.
client.GetAudioTranscription(id, RequestContent.Create(model, new ModelReaderWriterOptions("MultipartFormData")), new RequestContext());

// Convenience method.
client.GetAudioTranscription(id, model);
chunyu3 commented 5 months ago

Duplicate of https://github.com/Azure/autorest.csharp/issues/4337