Closed — seachenjy closed this issue 2 years ago.
I got it working ^_^
package main
import (
"fmt"
"io/ioutil"
tf "github.com/galeone/tensorflow/tensorflow/go"
"github.com/galeone/tensorflow/tensorflow/go/op"
)
// main loads the SavedModel in ./mobilenet_v2_140_224 (tag "serve"), converts
// ./d.png into an input tensor with makeTensorFromImage, runs one inference
// and prints the value of the first fetched tensor.
//
// Every failure is reported and stops the program. Previously it continued
// after a failed model load or a failed Run. It then panicked on a nil model
// or an empty result slice, which hid the real error.
func main() {
	const modelPath = "./mobilenet_v2_140_224"

	// Load the model.
	m, err := tf.LoadSavedModel(modelPath, []string{"serve"}, nil)
	if err != nil {
		// The model failed to load; m is unusable, so stop here.
		fmt.Printf("err: %v\n", err)
		return
	}
	// The Close error is ignored: the program is about to exit anyway.
	defer m.Session.Close()

	// Uncomment to print every operation in the graph; useful for finding the
	// input/output operation names used below.
	// for _, op := range m.Graph.Operations() {
	// 	fmt.Printf("Op name: %v \r\n", op.Name())
	// }

	tensorX, err := makeTensorFromImage("./d.png")
	if err != nil {
		fmt.Printf("err: %s\n", err.Error())
		return
	}

	// The input operation name must match your model's input.
	input, err := graphOutput(m.Graph, "serving_default_input")
	if err != nil {
		fmt.Printf("err: %s\n", err.Error())
		return
	}
	// The output operation name must also match your model.
	output, err := graphOutput(m.Graph, "StatefulPartitionedCall")
	if err != nil {
		fmt.Printf("err: %s\n", err.Error())
		return
	}

	result, err := m.Session.Run(
		map[tf.Output]*tf.Tensor{input: tensorX},
		[]tf.Output{output},
		nil)
	if err != nil {
		// Inference failed; result is empty, so do not index it.
		fmt.Printf("---err: %s\n", err.Error())
		return
	}
	fmt.Println(result[0].Value())
}

// graphOutput returns output 0 of the operation called name in g. It returns
// an error when g has no such operation. Without this check, the nil
// *Operation returned by Graph.Operation only fails later inside Session.Run.
func graphOutput(g *tf.Graph, name string) (tf.Output, error) {
	operation := g.Operation(name)
	if operation == nil {
		return tf.Output{}, fmt.Errorf("operation %q not found in graph", name)
	}
	return operation.Output(0), nil
}
// makeTensorFromImage reads the encoded image at filename and runs it through
// the graph built by makeTransformImageGraph. It returns that graph's output
// tensor.
//
// The decoder is chosen from the file contents. Data starting with the PNG
// signature is decoded as "png"; anything else is decoded as "jpeg".
// Previously "jpeg" was always requested, so PNG inputs were fed to the JPEG
// decoder.
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	// The transform graph takes the raw encoded bytes as a single string tensor.
	tensor, err := tf.NewTensor(string(data))
	if err != nil {
		return nil, err
	}
	graph, input, output, err := makeTransformImageGraph(imageFormatOf(data))
	if err != nil {
		return nil, err
	}
	session, err := tf.NewSession(graph, nil)
	if err != nil {
		return nil, err
	}
	defer session.Close()
	normalized, err := session.Run(
		map[tf.Output]*tf.Tensor{input: tensor},
		[]tf.Output{output},
		nil)
	if err != nil {
		return nil, err
	}
	return normalized[0], nil
}

// pngSignature is the fixed 8-byte header every PNG file starts with.
const pngSignature = "\x89PNG\r\n\x1a\n"

// imageFormatOf returns "png" when data begins with the PNG signature and
// "jpeg" otherwise. These are the two format selectors makeTransformImageGraph
// distinguishes.
func imageFormatOf(data []byte) string {
	if len(data) >= len(pngSignature) && string(data[:len(pngSignature)]) == pngSignature {
		return "png"
	}
	return "jpeg"
}
// makeTransformImageGraph builds a small preprocessing graph that:
//   - takes one encoded image as a string placeholder;
//   - decodes it with 3 channels (DecodePng when imageFormat is "png",
//     DecodeJpeg for any other value);
//   - casts the pixels to float and adds a leading batch dimension;
//   - resizes to 224x224 with ResizeBilinear;
//   - computes (value-mean)/scale for every element.
//
// It returns the finalized graph, the placeholder to feed and the output to
// fetch, plus any error from finalizing the scope.
//
// NOTE(review): mean=117 and scale=1 look like constants carried over from a
// different model's preprocessing; confirm they match what the loaded model
// expects as input.
func makeTransformImageGraph(imageFormat string) (graph *tf.Graph, input, output tf.Output, err error) {
	const (
		height, width = 224, 224
		mean          = float32(117)
		scale         = float32(1)
	)
	root := op.NewScope()
	input = op.Placeholder(root, tf.String)

	var decoded tf.Output
	switch imageFormat {
	case "png":
		decoded = op.DecodePng(root, input, op.DecodePngChannels(3))
	default:
		decoded = op.DecodeJpeg(root, input, op.DecodeJpegChannels(3))
	}

	// Each step names its result; ops are created in the same order as a
	// single nested expression would create them.
	pixels := op.Cast(root, decoded, tf.Float)
	batch := op.ExpandDims(root, pixels, op.Const(root.SubScope("make_batch"), int32(0)))
	resized := op.ResizeBilinear(root, batch, op.Const(root.SubScope("size"), []int32{height, width}))
	centered := op.Sub(root, resized, op.Const(root.SubScope("mean"), mean))
	output = op.Div(root, centered, op.Const(root.SubScope("scale"), scale))

	graph, err = root.Finalize()
	return graph, input, output, err
}
Errors:
It does not work. Please help.