robbiet480 / go.nut

A Golang library for interacting with NUT (Network UPS Tools)
https://godoc.org/github.com/robbiet480/go.nut
MIT License
30 stars 15 forks source link

Addition of Socket Timeout Support (incl. Patch) #8

Closed AdamLeyshon closed 2 years ago

AdamLeyshon commented 2 years ago

Hi everyone, I'm raising this issue on behalf of @Malinskiy (Anton Malinskiy).

He has created a patch that adds socket timeouts. A Telegraf plugin that monitors UPSes via go.nut depends on this patch (see https://github.com/influxdata/telegraf/pull/9890), and the Telegraf maintainers would prefer that the plugin depend on the original repo rather than on a fork, since a fork creates another point of failure.

timeout.txt (attached as a .txt file because GitHub won't let you upload patch files 🤔) originates from https://github.com/Malinskiy/go.nut/commit/bb4669c66e1ed80138a6326b3f5185f063f07d46

Some changes will probably be needed to repoint the code back at this module rather than his fork. I'm a beginner at Go and could give it a try, but I'd rather let someone more adept than me do it, for fear of introducing a breaking change.

From bb4669c66e1ed80138a6326b3f5185f063f07d46 Mon Sep 17 00:00:00 2001
From: Anton Malinskiy <anton@malinskiy.com>
Date: Sun, 3 Oct 2021 19:56:06 +1100
Subject: [PATCH] feat(nut): implement socket timeouts

---
 example_test.go |  7 +++----
 go.mod          |  3 +++
 nut.go          | 47 +++++++++++++++++++++++++++--------------------
 3 files changed, 33 insertions(+), 24 deletions(-)
 create mode 100644 go.mod

diff --git a/example_test.go b/example_test.go
index 6a23689..31936a1 100644
--- a/example_test.go
+++ b/example_test.go
@@ -2,17 +2,16 @@ package nut

 import (
        "fmt"
-
-       nut "github.com/robbiet480/go.nut"
+       "time"
 )

 // This example connects to NUT, authenticates and returns the first UPS listed.
 func ExampleGetUPSList() {
-       client, connectErr := nut.Connect("127.0.0.1")
+       client, connectErr := Connect("127.0.0.1", 10*time.Second, 30*time.Second)
        if connectErr != nil {
                fmt.Print(connectErr)
        }
-       _, authenticationError = client.Authenticate("username", "password")
+       _, authenticationError := client.Authenticate("username", "password")
        if authenticationError != nil {
                fmt.Print(authenticationError)
        }
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..cba4e3e
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/Malinskiy/go.nut
+
+go 1.17
diff --git a/nut.go b/nut.go
index 34fd406..9705867 100644
--- a/nut.go
+++ b/nut.go
@@ -8,32 +8,33 @@ import (
        "fmt"
        "net"
        "strings"
+       "time"
 )

 // Client contains information about the NUT server as well as the connection.
 type Client struct {
-       Version         string
-       ProtocolVersion string
-       Hostname        net.Addr
-       conn            *net.TCPConn
+       opTimeout time.Duration
+       conn      net.Conn
 }

 // Connect accepts a hostname/IP string and creates a connection to NUT, returning a Client.
-func Connect(hostname string) (Client, error) {
-       tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:3493", hostname))
+func Connect(hostname string, connectTimeout time.Duration, opTimeout time.Duration) (*Client, error) {
+       _, _, err := net.SplitHostPort(hostname)
        if err != nil {
-               return Client{}, err
+               hostname = net.JoinHostPort(hostname, "3493")
        }
-       conn, err := net.DialTCP("tcp", nil, tcpAddr)
+       d := net.Dialer{
+               Timeout: connectTimeout,
+       }
+       conn, err := d.Dial("tcp", hostname)
        if err != nil {
-               return Client{}, err
+               return nil, err
        }
-       client := Client{
-               Hostname: conn.RemoteAddr(),
-               conn:     conn,
+
+       client := &Client{
+               opTimeout: opTimeout,
+               conn:      conn,
        }
-       client.GetVersion()
-       client.GetNetworkProtocolVersion()
        return client, nil
 }

@@ -55,6 +56,10 @@ func (c *Client) ReadResponse(endLine string, multiLineResponse bool) (resp []st
        response := []string{}

        for {
+               err = c.conn.SetReadDeadline(time.Now().Add(c.opTimeout))
+               if err != nil {
+                       return nil, err
+               }
                line, err := connbuff.ReadString('\n')
                if err != nil {
                        return nil, fmt.Errorf("error reading response: %v", err)
@@ -79,18 +84,22 @@ func (c *Client) SendCommand(cmd string) (resp []string, err error) {
        if strings.HasPrefix(cmd, "USERNAME ") || strings.HasPrefix(cmd, "PASSWORD ") || strings.HasPrefix(cmd, "SET ") || strings.HasPrefix(cmd, "HELP ") || strings.HasPrefix(cmd, "VER ") || strings.HasPrefix(cmd, "NETVER ") {
                endLine = "OK\n"
        }
-       _, err = fmt.Fprint(c.conn, cmd)
+       err = c.conn.SetWriteDeadline(time.Now().Add(c.opTimeout))
+       if err != nil {
+               return nil, err
+       }
+       _, err = c.conn.Write([]byte(cmd))
        if err != nil {
-               return []string{}, err
+               return nil, err
        }

        resp, err = c.ReadResponse(endLine, strings.HasPrefix(cmd, "LIST "))
        if err != nil {
-               return []string{}, err
+               return nil, err
        }

        if strings.HasPrefix(resp[0], "ERR ") {
-               return []string{}, errorForMessage(strings.Split(resp[0], " ")[1])
+               return nil, errorForMessage(strings.Split(resp[0], " ")[1])
        }

        return resp, nil
@@ -141,13 +150,11 @@ func (c *Client) Help() (string, error) {
 // GetVersion returns the the version of the server currently in use.
 func (c *Client) GetVersion() (string, error) {
        versionResponse, err := c.SendCommand("VER")
-       c.Version = versionResponse[0]
        return versionResponse[0], err
 }

 // GetNetworkProtocolVersion returns the version of the network protocol currently in use.
 func (c *Client) GetNetworkProtocolVersion() (string, error) {
        versionResponse, err := c.SendCommand("NETVER")
-       c.ProtocolVersion = versionResponse[0]
        return versionResponse[0], err
 }
robbiet480 commented 2 years ago

This patch looks good as is. Feel free to submit as a PR. It can be "your" first Go contribution!

AdamLeyshon commented 2 years ago

I've opened a PR. I did my best to merge these changes with the custom port support that was added.