GantMan / nsfw_model

Keras model of NSFW detector
Other
1.8k stars 279 forks source link

how to use the model in golang #101

Closed: seachenjy closed this issue 2 years ago

seachenjy commented 2 years ago
package main

import (
    "fmt"
    "io/ioutil"

    tf "github.com/galeone/tensorflow/tensorflow/go"
    "github.com/galeone/tensorflow/tensorflow/go/op"
)

// main loads the NSFW SavedModel, converts one image into an input tensor and
// prints the prediction returned by the model.
//
// The operation names "serving_default_input" and "StatefulPartitionedCall"
// depend on how the model was exported; uncomment the loop below to list the
// graph's operations when they need adjusting.
func main() {
    modelPath := "./mobilenet_v2_140_224"
    // Load the SavedModel tagged "serve".
    m, err := tf.LoadSavedModel(modelPath, []string{"serve"}, nil)
    if err != nil {
        // The model failed to load; m is unusable, so stop instead of
        // dereferencing it below.
        fmt.Printf("err: %v", err)
        return
    }
    defer m.Session.Close()

    // Print every operation in the graph:
    // for _, op := range m.Graph.Operations() {
    //  fmt.Printf("Op name: %v \r\n", op.Name())
    // }

    tensorX, err := makeTensorFromeImage("./p1.jpeg")
    if err != nil {
        fmt.Printf("err: %s", err.Error())
        return
    }

    // The input operation name must match the model's serving signature.
    inputOp := m.Graph.Operation("serving_default_input")
    // The output operation name must also match the model.
    outputOp := m.Graph.Operation("StatefulPartitionedCall")
    if inputOp == nil || outputOp == nil {
        // Graph.Operation returns nil for unknown names; calling Output on a
        // nil operation would panic.
        fmt.Printf("err: operation serving_default_input or StatefulPartitionedCall not found")
        return
    }

    result, err := m.Session.Run(
        map[tf.Output]*tf.Tensor{inputOp.Output(0): tensorX},
        []tf.Output{outputOp.Output(0)},
        nil,
    )
    if err != nil {
        // Prediction failed; there is no result to print.
        fmt.Printf("---err: %s  ", err.Error())
        return
    }
    fmt.Println(result)
}

// makeTensorFromeImage reads a JPEG file and converts it into the float32
// image batch the model's input expects.
//
// Feeding the raw file bytes as a string tensor fails with
// "Expects arg[0] to be float but string is provided". Instead, the bytes are
// decoded as a 3-channel image, cast to float32 and given a leading batch
// dimension. The result is then resized to 224x224 and divided by 255, all
// inside a small throw-away TensorFlow graph.
//
// NOTE(review): the divide-by-255 ([0, 1]) scaling assumes the model follows
// the TF Hub image-input convention used by nsfw_model — confirm against the
// exported model.
//
// Returns the preprocessed tensor, or the first error from reading the file,
// building the graph, or running it.
func makeTensorFromeImage(filename string) (*tf.Tensor, error) {
    raw, err := ioutil.ReadFile(filename)
    if err != nil {
        return nil, err
    }
    encoded, err := tf.NewTensor(string(raw))
    if err != nil {
        return nil, err
    }

    s := op.NewScope()
    input := op.Placeholder(s, tf.String)
    decoded := op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
    // The model takes a batch, so prepend a batch dimension of size 1.
    batched := op.ExpandDims(s,
        op.Cast(s, decoded, tf.Float),
        op.Const(s.SubScope("make_batch"), int32(0)))
    resized := op.ResizeBilinear(s, batched,
        op.Const(s.SubScope("size"), []int32{224, 224}))
    output := op.Div(s, resized, op.Const(s.SubScope("scale"), float32(255)))

    graph, err := s.Finalize()
    if err != nil {
        return nil, err
    }
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        return nil, err
    }
    defer session.Close()

    out, err := session.Run(
        map[tf.Output]*tf.Tensor{input: encoded},
        []tf.Output{output},
        nil)
    if err != nil {
        return nil, err
    }
    return out[0], nil
}

Error:

---err: Expects arg[0] to be float but string is provided  []

It does not work. Can anyone help?

seachenjy commented 2 years ago

I got it working ^_^

package main

import (
    "fmt"
    "io/ioutil"

    tf "github.com/galeone/tensorflow/tensorflow/go"
    "github.com/galeone/tensorflow/tensorflow/go/op"
)

// main loads the NSFW SavedModel, preprocesses one image and prints the value
// of the model's first output tensor.
//
// The operation names "serving_default_input" and "StatefulPartitionedCall"
// depend on how the model was exported; uncomment the loop below to list the
// graph's operations when they need adjusting.
func main() {
    modelPath := "./mobilenet_v2_140_224"
    // Load the SavedModel tagged "serve".
    m, err := tf.LoadSavedModel(modelPath, []string{"serve"}, nil)
    if err != nil {
        // The model failed to load; m is unusable, so stop instead of
        // dereferencing it below.
        fmt.Printf("err: %v", err)
        return
    }
    defer m.Session.Close()

    // Print every operation in the graph:
    // for _, op := range m.Graph.Operations() {
    //  fmt.Printf("Op name: %v \r\n", op.Name())
    // }

    tensorX, err := makeTensorFromImage("./d.png")
    if err != nil {
        fmt.Printf("err: %s", err.Error())
        return
    }

    // The input operation name must match the model's serving signature.
    inputOp := m.Graph.Operation("serving_default_input")
    // The output operation name must also match the model.
    outputOp := m.Graph.Operation("StatefulPartitionedCall")
    if inputOp == nil || outputOp == nil {
        // Graph.Operation returns nil for unknown names; calling Output on a
        // nil operation would panic.
        fmt.Printf("err: operation serving_default_input or StatefulPartitionedCall not found")
        return
    }

    result, err := m.Session.Run(
        map[tf.Output]*tf.Tensor{inputOp.Output(0): tensorX},
        []tf.Output{outputOp.Output(0)},
        nil,
    )
    if err != nil {
        // Prediction failed; result is empty, so result[0] would panic.
        fmt.Printf("---err: %s  ", err.Error())
        return
    }
    fmt.Println(result[0].Value())
}

// makeTensorFromImage reads an encoded PNG or JPEG file and returns the
// preprocessed float32 tensor produced by the graph that
// makeTransformImageGraph builds.
//
// The decoder is chosen from the file contents. Data starting with the 8-byte
// PNG signature uses the "png" path; anything else uses the "jpeg" path.
//
// Returns the first error from reading the file, building the graph, or
// running it.
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
    raw, err := ioutil.ReadFile(filename)
    if err != nil {
        return nil, err
    }
    tensor, err := tf.NewTensor(string(raw))
    if err != nil {
        return nil, err
    }

    // Sniff the PNG signature rather than hard-coding JPEG, so PNG inputs go
    // through the PNG decoder that makeTransformImageGraph supports.
    const pngSignature = "\x89PNG\r\n\x1a\n"
    format := "jpeg"
    if len(raw) >= len(pngSignature) && string(raw[:len(pngSignature)]) == pngSignature {
        format = "png"
    }

    graph, input, output, err := makeTransformImageGraph(format)
    if err != nil {
        return nil, err
    }
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        return nil, err
    }
    defer session.Close()

    normalized, err := session.Run(
        map[tf.Output]*tf.Tensor{input: tensor},
        []tf.Output{output},
        nil)
    if err != nil {
        return nil, err
    }
    return normalized[0], nil
}

// makeTransformImageGraph builds a TensorFlow graph that turns an encoded
// image into the model's input batch.
//
// input is a string placeholder that takes the encoded file bytes. output
// decodes them as a 3-channel image (PNG when imageFormat == "png", otherwise
// JPEG) and casts the pixels to float32. It then adds a batch dimension of
// size 1, resizes bilinearly to H x W, and applies (value - Mean) / Scale.
//
// Mean = 0 and Scale = 255 map pixels into [0, 1]. Values of 117 and 1 come
// from the Inception label_image example and do not fit this model.
// NOTE(review): [0, 1] follows the TF Hub image-input convention and the
// nsfw_model Python predictor (pixels / 255) — confirm against the exported
// model.
//
// err is whatever Scope.Finalize reports when the graph cannot be built.
func makeTransformImageGraph(imageFormat string) (graph *tf.Graph, input, output tf.Output, err error) {
    const (
        H, W  = 224, 224
        Mean  = float32(0)
        Scale = float32(255)
    )
    s := op.NewScope()
    input = op.Placeholder(s, tf.String)
    // Decode PNG or JPEG into an RGB image.
    var decode tf.Output
    if imageFormat == "png" {
        decode = op.DecodePng(s, input, op.DecodePngChannels(3))
    } else {
        decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
    }
    // Div and Sub compute (value-Mean)/Scale for each pixel.
    output = op.Div(s,
        op.Sub(s,
            // Resize to H x W with bilinear interpolation.
            op.ResizeBilinear(s,
                // Create a batch containing a single image.
                op.ExpandDims(s,
                    // Use the decoded pixel values as float32.
                    op.Cast(s, decode, tf.Float),
                    op.Const(s.SubScope("make_batch"), int32(0))),
                op.Const(s.SubScope("size"), []int32{H, W})),
            op.Const(s.SubScope("mean"), Mean)),
        op.Const(s.SubScope("scale"), Scale))
    graph, err = s.Finalize()
    return graph, input, output, err
}