golang / protobuf

Go support for Google's protocol buffers
BSD 3-Clause "New" or "Revised" License

type-safe Extensions API using Go Generics #1643

Open stapelberg opened 3 months ago

stapelberg commented 3 months ago

Splitting this suggestion by @dsnet out of the https://go.dev/cl/607995 review discussion:

When designing the extensions API, I was eyeing Go generics (which were still in the works at the time), and it seems that we're finally allowed to use them, since the module's language version now seems to be Go 1.20 (generics have been available since Go 1.18).

One possible API is:

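package proto

import "google.golang.org/protobuf/reflect/protoreflect"

// Extension is a type-safe wrapper around an extension type. M is the
// extended message type and V is the Go type of the extension field's value.
type Extension[M Message, V any] struct {
    protoreflect.ExtensionType
}

// Has reports whether the extension field is populated in m.
func (x Extension[M, V]) Has(m M) bool {
    return HasExtension(m, x.ExtensionType)
}

// Get returns the value of the extension field in m as a V.
func (x Extension[M, V]) Get(m M) V {
    return GetExtension(m, x.ExtensionType).(V)
}

// Set stores v in the extension field of m.
func (x Extension[M, V]) Set(m M, v V) {
    SetExtension(m, x.ExtensionType, v)
}

// Clear clears the extension field of m.
func (x Extension[M, V]) Clear(m M) {
    ClearExtension(m, x.ExtensionType)
}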
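To make the compile-time type-safety argument concrete without depending on the protobuf module, here is a minimal, stdlib-only toy model of the same wrapper shape. `Message`, `ExtensionType`, `fieldNum`, `Extensions`, and the lowercase `hasExtension`/`getExtension`/`setExtension`/`clearExtension` helpers are all hypothetical stand-ins, not the real protobuf API, and the field number 101 is arbitrary. It is a sketch of the idea, not the proposed implementation.

```go
package main

import "fmt"

// Message and ExtensionType are hypothetical, stdlib-only stand-ins for
// proto.Message and protoreflect.ExtensionType.
type Message interface{ extFields() map[int32]any }

type ExtensionType interface{ Number() int32 }

// fieldNum is a toy extension type identified only by its field number.
type fieldNum int32

func (n fieldNum) Number() int32 { return int32(n) }

// Untyped accessors playing the role of proto.HasExtension and friends:
// values go in and come out as `any`, so nothing is checked at compile time.
func hasExtension(m Message, xt ExtensionType) bool {
	_, ok := m.extFields()[xt.Number()]
	return ok
}

func getExtension(m Message, xt ExtensionType) any { return m.extFields()[xt.Number()] }

func setExtension(m Message, xt ExtensionType, v any) { m.extFields()[xt.Number()] = v }

func clearExtension(m Message, xt ExtensionType) { delete(m.extFields(), xt.Number()) }

// Extension has the same shape as the proposed proto.Extension wrapper.
type Extension[M Message, V any] struct{ ExtensionType }

func (x Extension[M, V]) Has(m M) bool { return hasExtension(m, x.ExtensionType) }

func (x Extension[M, V]) Get(m M) V {
	// Comma-ok so that an unset field yields the zero V in this toy
	// instead of panicking on a nil interface.
	v, _ := getExtension(m, x.ExtensionType).(V)
	return v
}

func (x Extension[M, V]) Set(m M, v V) { setExtension(m, x.ExtensionType, v) }

func (x Extension[M, V]) Clear(m M) { clearExtension(m, x.ExtensionType) }

// Extensions is a toy extendee message backed by a map.
type Extensions struct{ ext map[int32]any }

func (e *Extensions) extFields() map[int32]any {
	if e.ext == nil {
		e.ext = make(map[int32]any)
	}
	return e.ext
}

// E_OptExtBool mimics a generated extension variable.
var E_OptExtBool = Extension[*Extensions, *bool]{fieldNum(101)}

func main() {
	m := &Extensions{}
	fmt.Println(E_OptExtBool.Has(m)) // false
	t := true
	E_OptExtBool.Set(m, &t)
	fmt.Println(E_OptExtBool.Has(m), *E_OptExtBool.Get(m)) // true true
	// E_OptExtBool.Set(m, "yes") is rejected at compile time: "yes" is not a *bool.
	E_OptExtBool.Clear(m)
	fmt.Println(E_OptExtBool.Has(m), E_OptExtBool.Get(m) == nil) // false true
}
```

The point of the design is that the extendee type M and the value type V travel with the extension variable, so passing the wrong message type or a wrongly typed value fails to compile instead of panicking at a type assertion at run time.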

And for all the generated E_xxx variables, we can have them wrap the underlying protoimpl.ExtensionInfo as a proto.Extension. For example:

var E_OptExtBool = proto.Extension[*Extensions, *bool]{&file_internal_testprotos_textpb2_test_proto_extTypes[0]}

I believe we can change the underlying concrete type of E_OptExtBool, as long as we keep all the original methods, since it has always exposed the unexported protoimpl.ExtensionInfo type.
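The reason existing callers should keep working is Go's method promotion: because the wrapper embeds the interface, every instantiation of it still satisfies that interface and can be passed to the existing non-generic functions. The following stdlib-only sketch illustrates this; `Describer`, `info`, `describe`, `Extensions`, and `extTypes` are hypothetical stand-ins for `protoreflect.ExtensionType`, `protoimpl.ExtensionInfo`, an existing untyped API such as `proto.HasExtension`, and the generated tables.

```go
package main

import "fmt"

// Describer is a hypothetical stand-in for protoreflect.ExtensionType.
type Describer interface{ FullName() string }

// info is a hypothetical stand-in for the generated protoimpl.ExtensionInfo.
type info struct{ name string }

func (i *info) FullName() string { return i.name }

// Extension embeds the interface, so FullName is promoted onto every instantiation.
type Extension[M, V any] struct{ Describer }

// describe stands in for an existing, non-generic API that accepts the interface.
func describe(d Describer) string { return d.FullName() }

// Extensions is a toy extendee type.
type Extensions struct{}

var extTypes = []info{{name: "test.opt_ext_bool"}}

// Before the change, the generated variable would be: var E_OptExtBool = &extTypes[0]
// After: the same pointer, wrapped. It still satisfies Describer.
var E_OptExtBool = Extension[*Extensions, *bool]{&extTypes[0]}

func main() {
	fmt.Println(describe(E_OptExtBool))  // test.opt_ext_bool
	fmt.Println(E_OptExtBool.FullName()) // test.opt_ext_bool
}
```

Note that only methods are promoted through the embedded interface; struct fields of the wrapped value are not reachable from the wrapper, so any code that reads fields directly off an E_xxx variable would presumably need to change.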

Technically, protoimpl.ExtensionInfo is "exported", but the entire protoimpl package carries a big scary warning:

WARNING: This package should only ever be imported by generated messages. The compatibility agreement covers nothing except for functionality needed to keep existing generated messages operational. Breakages that occur due to unauthorized usages of this package are not the author's responsibility.

I gave this a quick try like so:

diff --git i/cmd/protoc-gen-go/internal_gengo/main.go w/cmd/protoc-gen-go/internal_gengo/main.go
index dfd16b27..e906b996 100644
--- i/cmd/protoc-gen-go/internal_gengo/main.go
+++ w/cmd/protoc-gen-go/internal_gengo/main.go
@@ -761,7 +761,7 @@ func genExtensions(g *protogen.GeneratedFile, f *fileInfo) {
    g.P("var ", extensionTypesVarName(f), " = []", protoimplPackage.Ident("ExtensionInfo"), "{")
    for _, x := range f.allExtensions {
        g.P("{")
-       g.P("ExtendedType: (*", x.Extendee.GoIdent, ")(nil),")
+       g.P("ExtendedType: (*", x.Extendee.GoIdent, ")(nil),") // ExtendedType:  (*base.BaseMessage)(nil),
        goType, pointer := fieldGoType(g, f, x.Extension)
        if pointer {
            goType = "*" + goType
@@ -792,6 +792,11 @@ func genExtensions(g *protogen.GeneratedFile, f *fileInfo) {
        g.P("// Extension fields to ", target, ".")
        g.P("var (")
        for _, x := range allExtensionsByTarget[target] {
+           goType, pointer := fieldGoType(g, f, x.Extension)
+           if pointer {
+               goType = "*" + goType
+           }
+
            xd := x.Desc
            typeName := xd.Kind().String()
            switch xd.Kind() {
@@ -812,7 +817,8 @@ func genExtensions(g *protogen.GeneratedFile, f *fileInfo) {
                x.Desc.ParentFile(),
                x.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
            g.P(leadingComments,
-               "E_", x.GoIdent, " = &", extensionTypesVarName(f), "[", allExtensionsByPtr[x], "]",
+               "E_", x.GoIdent, " = ", protoPackage.Ident("Extension"),
+               "[*", x.Extendee.GoIdent, ", ", goType, "]{&", extensionTypesVarName(f), "[", allExtensionsByPtr[x], "]}",
                trailingComment(x.Comments.Trailing))
        }
        g.P(")")
diff --git i/proto/extension.go w/proto/extension.go
index d248f292..e478d802 100644
--- i/proto/extension.go
+++ w/proto/extension.go
@@ -8,6 +8,26 @@ import (
    "google.golang.org/protobuf/reflect/protoreflect"
 )

+type Extension[M Message, V any] struct {
+   protoreflect.ExtensionType
+}
+
+func (x Extension[M, V]) Has(m M) bool {
+   return HasExtension(m, x.ExtensionType)
+}
+
+func (x Extension[M, V]) Get(m M) V {
+   return GetExtension(m, x.ExtensionType).(V)
+}
+
+func (x Extension[M, V]) Set(m M, v V) {
+   SetExtension(m, x.ExtensionType, v)
+}
+
+func (x Extension[M, V]) Clear(m M) {
+   ClearExtension(m, x.ExtensionType)
+}
+
 // HasExtension reports whether an extension field is populated.
 // It returns false if m is invalid or if xt does not extend m.
 func HasExtension(m Message, xt protoreflect.ExtensionType) bool {

There are a number of test failures. The most concerning one probably is:

--- FAIL: TestHasExtensionNoAlloc (0.00s)
    --- FAIL: TestHasExtensionNoAlloc/Nil (0.00s)
        extension_test.go:156: proto.HasExtension should not allocate, but allocated 1.00x per run
    --- FAIL: TestHasExtensionNoAlloc/Eager (0.00s)
        extension_test.go:156: proto.HasExtension should not allocate, but allocated 1.00x per run
    --- FAIL: TestHasExtensionNoAlloc/Lazy (0.00s)
        extension_test.go:156: proto.HasExtension should not allocate, but allocated 1.00x per run

This needs more investigation. Inside Google, there is an additional roadblock to landing this: the protodeps indirection package uses type aliases, which don’t support generics (yet).

puellanivis commented 3 months ago

Looking forward to seeing how generics can simplify some of the protobuf code, where generics would not just “enable metaprogramming” but actually provide stronger type safety at compile time. 👍

Adding a bit to this:

there is an additional roadblock to landing this, which is the protodeps indirection package that uses type aliases, which don’t support generics (yet).

I have had luck with type aliases that fold a generic type’s type parameters into the alias. There are still “leaks” here and there, though, where I run into things like “cannot infer T” even though T is a type parameter of a type the compiler darn well does know, and/or T simply isn’t used by that particular instance of the generic type. 🤷‍♀️
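To illustrate the distinction being drawn: an alias for a fully instantiated generic type (all type parameters folded in) has been legal since Go 1.18, whereas re-exporting the generic type itself requires a parameterized alias, which Go only fully supports starting with Go 1.24. In this sketch, `Ext`, `Msg`, and `BoolExt` are hypothetical names standing in for `proto.Extension` and the kind of alias an indirection package might declare.

```go
package main

import "fmt"

// Ext is a hypothetical generic type standing in for proto.Extension.
type Ext[M, V any] struct{ value V }

func (x Ext[M, V]) Get() V { return x.value }

// Msg is a toy extendee type.
type Msg struct{}

// An alias for a fully instantiated generic type is legal since Go 1.18:
// every type parameter is folded in, so the alias itself has none.
type BoolExt = Ext[*Msg, bool]

// Re-exporting the generic type itself needs a parameterized alias,
//
//	type Alias[M, V any] = Ext[M, V]
//
// which the compiler only fully supports starting with Go 1.24.

func main() {
	x := BoolExt{value: true}
	var y Ext[*Msg, bool] = x     // identical types, so no conversion is needed
	fmt.Println(x.Get(), y.Get()) // true true
}
```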