feat: support Var API (#31)
diff --git a/go.mod b/go.mod
index 18b7b72..2a708fa 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@
require (
github.com/ReneKroon/ttlcache/v2 v2.4.0
- github.com/api7/ext-plugin-proto v0.2.1
+ github.com/api7/ext-plugin-proto v0.3.0
github.com/google/flatbuffers v2.0.0+incompatible
github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.7.0
diff --git a/go.sum b/go.sum
index 54dfc62..f58ccbe 100644
--- a/go.sum
+++ b/go.sum
@@ -26,6 +26,8 @@
github.com/api7/ext-plugin-proto v0.2.0/go.mod h1:8dbdAgCESeqwZ0IXirbjLbshEntmdrAX3uet+LW3jVU=
github.com/api7/ext-plugin-proto v0.2.1 h1:NRz4CxPM10KPHAJSv+5jcOMjQBJN8mninu9V6O62Mxw=
github.com/api7/ext-plugin-proto v0.2.1/go.mod h1:8dbdAgCESeqwZ0IXirbjLbshEntmdrAX3uet+LW3jVU=
+github.com/api7/ext-plugin-proto v0.3.0 h1:exofn/9DIPpqMAu033M7TBXsWr/O0LUcIKuLUCF58Xk=
+github.com/api7/ext-plugin-proto v0.3.0/go.mod h1:8dbdAgCESeqwZ0IXirbjLbshEntmdrAX3uet+LW3jVU=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
diff --git a/internal/http/request.go b/internal/http/request.go
index 3728d72..09ffff7 100644
--- a/internal/http/request.go
+++ b/internal/http/request.go
@@ -18,22 +18,31 @@
package http
import (
+ "encoding/binary"
"net"
"net/http"
"net/url"
"reflect"
"sync"
- pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http"
"github.com/api7/ext-plugin-proto/go/A6"
+ ei "github.com/api7/ext-plugin-proto/go/A6/ExtraInfo"
hrc "github.com/api7/ext-plugin-proto/go/A6/HTTPReqCall"
flatbuffers "github.com/google/flatbuffers/go"
+
+ "github.com/apache/apisix-go-plugin-runner/internal/util"
+ "github.com/apache/apisix-go-plugin-runner/pkg/common"
+ pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http"
+ "github.com/apache/apisix-go-plugin-runner/pkg/log"
)
type Request struct {
// the root of the flatbuffers HTTPReqCall Request msg
r *hrc.Req
+ conn net.Conn
+ extraInfoHeader []byte
+
path []byte
hdr *Header
@@ -41,6 +50,8 @@
args url.Values
rawArgs url.Values
+
+ vars map[string][]byte
}
func (r *Request) ConfToken() uint32 {
@@ -118,10 +129,44 @@
return r.args
}
+func (r *Request) Var(name string) ([]byte, error) {
+ if r.vars == nil {
+ r.vars = map[string][]byte{}
+ }
+
+ var v []byte
+ var found bool
+
+ if v, found = r.vars[name]; !found {
+ var err error
+
+ builder := util.GetBuilder()
+ varName := builder.CreateString(name)
+ ei.VarStart(builder)
+ ei.VarAddName(builder, varName)
+ varInfo := ei.VarEnd(builder)
+ v, err = r.askExtraInfo(builder, ei.InfoVar, varInfo)
+ util.PutBuilder(builder)
+
+ if err != nil {
+ return nil, err
+ }
+
+ r.vars[name] = v
+ }
+ return v, nil
+}
+
func (r *Request) Reset() {
r.path = nil
r.hdr = nil
r.args = nil
+
+ r.vars = nil
+ r.conn = nil
+
+ // Keep the fields below
+ // r.extraInfoHeader = nil
}
func (r *Request) FetchChanges(id uint32, builder *flatbuffers.Builder) bool {
@@ -230,6 +275,64 @@
return true
}
+func (r *Request) BindConn(c net.Conn) {
+ r.conn = c
+}
+
+func (r *Request) askExtraInfo(builder *flatbuffers.Builder,
+ infoType ei.Info, info flatbuffers.UOffsetT) ([]byte, error) {
+
+ ei.ReqStart(builder)
+ ei.ReqAddInfoType(builder, infoType)
+ ei.ReqAddInfo(builder, info)
+ eiRes := ei.ReqEnd(builder)
+ builder.Finish(eiRes)
+
+ c := r.conn
+ if len(r.extraInfoHeader) == 0 {
+ r.extraInfoHeader = make([]byte, util.HeaderLen)
+ }
+ header := r.extraInfoHeader
+
+ out := builder.FinishedBytes()
+ size := len(out)
+ binary.BigEndian.PutUint32(header, uint32(size))
+ header[0] = util.RPCExtraInfo
+
+ n, err := c.Write(header)
+ if err != nil {
+ util.WriteErr(n, err)
+ return nil, common.ErrConnClosed
+ }
+
+ n, err = c.Write(out)
+ if err != nil {
+ util.WriteErr(n, err)
+ return nil, common.ErrConnClosed
+ }
+
+ n, err = c.Read(header)
+ if util.ReadErr(n, err, util.HeaderLen) {
+ return nil, common.ErrConnClosed
+ }
+
+ ty := header[0]
+ header[0] = 0
+ length := binary.BigEndian.Uint32(header)
+
+ log.Infof("receive rpc type: %d data length: %d", ty, length)
+
+ buf := make([]byte, length)
+ n, err = c.Read(buf)
+ if util.ReadErr(n, err, int(length)) {
+ return nil, common.ErrConnClosed
+ }
+
+ resp := ei.GetRootAsResp(buf, 0)
+ res := resp.ResultBytes()
+ return res, nil
+}
+
var reqPool = sync.Pool{
New: func() interface{} {
return &Request{}
diff --git a/internal/http/request_test.go b/internal/http/request_test.go
index c14041a..9d56ce8 100644
--- a/internal/http/request_test.go
+++ b/internal/http/request_test.go
@@ -18,16 +18,20 @@
package http
import (
+ "encoding/binary"
"net"
"net/http"
"net/url"
"testing"
- "github.com/apache/apisix-go-plugin-runner/internal/util"
"github.com/api7/ext-plugin-proto/go/A6"
+ ei "github.com/api7/ext-plugin-proto/go/A6/ExtraInfo"
hrc "github.com/api7/ext-plugin-proto/go/A6/HTTPReqCall"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
+
+ "github.com/apache/apisix-go-plugin-runner/internal/util"
+ "github.com/apache/apisix-go-plugin-runner/pkg/common"
)
func getRewriteAction(t *testing.T, b *flatbuffers.Builder) *hrc.Rewrite {
@@ -43,6 +47,17 @@
return nil
}
+func getVarInfo(t *testing.T, req *ei.Req) *ei.Var {
+ tab := &flatbuffers.Table{}
+ if req.Info(tab) {
+ assert.Equal(t, ei.InfoVar, req.InfoType())
+ info := &ei.Var{}
+ info.Init(tab.Bytes, tab.Pos)
+ return info
+ }
+ return nil
+}
+
type pair struct {
name string
value string
@@ -254,3 +269,115 @@
assert.Equal(t, exp, res)
assert.Equal(t, "del", deleted)
}
+
+func TestVar(t *testing.T) {
+ out := buildReq(reqOpt{})
+ r := CreateRequest(out)
+
+ cc, sc := net.Pipe()
+ r.BindConn(cc)
+
+ go func() {
+ header := make([]byte, util.HeaderLen)
+ n, err := sc.Read(header)
+ if util.ReadErr(n, err, util.HeaderLen) {
+ return
+ }
+
+ ty := header[0]
+ assert.Equal(t, byte(util.RPCExtraInfo), ty)
+ header[0] = 0
+ length := binary.BigEndian.Uint32(header)
+
+ buf := make([]byte, length)
+ n, err = sc.Read(buf)
+ if util.ReadErr(n, err, int(length)) {
+ return
+ }
+
+ req := ei.GetRootAsReq(buf, 0)
+ info := getVarInfo(t, req)
+ assert.Equal(t, "request_time", string(info.Name()))
+
+ builder := util.GetBuilder()
+ res := builder.CreateByteVector([]byte("1.0"))
+ ei.RespStart(builder)
+ ei.RespAddResult(builder, res)
+ eiRes := ei.RespEnd(builder)
+ builder.Finish(eiRes)
+ out := builder.FinishedBytes()
+ size := len(out)
+ binary.BigEndian.PutUint32(header, uint32(size))
+ header[0] = util.RPCExtraInfo
+
+ n, err = sc.Write(header)
+ if err != nil {
+ util.WriteErr(n, err)
+ return
+ }
+
+ n, err = sc.Write(out)
+ if err != nil {
+ util.WriteErr(n, err)
+ return
+ }
+ }()
+
+ for i := 0; i < 2; i++ {
+ v, err := r.Var("request_time")
+ assert.Nil(t, err)
+ assert.Equal(t, "1.0", string(v))
+ }
+}
+
+func TestVar_FailedToSendExtraInfoReq(t *testing.T) {
+ out := buildReq(reqOpt{})
+ r := CreateRequest(out)
+
+ cc, sc := net.Pipe()
+ r.BindConn(cc)
+
+ go func() {
+ header := make([]byte, util.HeaderLen)
+ n, err := sc.Read(header)
+ if util.ReadErr(n, err, util.HeaderLen) {
+ return
+ }
+ sc.Close()
+ }()
+
+ _, err := r.Var("request_time")
+ assert.Equal(t, common.ErrConnClosed, err)
+}
+
+func TestVar_FailedToReadExtraInfoResp(t *testing.T) {
+ out := buildReq(reqOpt{})
+ r := CreateRequest(out)
+
+ cc, sc := net.Pipe()
+ r.BindConn(cc)
+
+ go func() {
+ header := make([]byte, util.HeaderLen)
+ n, err := sc.Read(header)
+ if util.ReadErr(n, err, util.HeaderLen) {
+ return
+ }
+
+ ty := header[0]
+ assert.Equal(t, byte(util.RPCExtraInfo), ty)
+ header[0] = 0
+ length := binary.BigEndian.Uint32(header)
+
+ buf := make([]byte, length)
+ n, err = sc.Read(buf)
+ if util.ReadErr(n, err, int(length)) {
+ return
+ }
+
+ sc.Close()
+ }()
+
+ _, err := r.Var("request_time")
+ assert.Equal(t, common.ErrConnClosed, err)
+}
diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go
index c82c7d6..c0e94a9 100644
--- a/internal/plugin/plugin.go
+++ b/internal/plugin/plugin.go
@@ -20,6 +20,7 @@
import (
"errors"
"fmt"
+ "net"
"net/http"
"sync"
@@ -132,8 +133,9 @@
return builder
}
-func HTTPReqCall(buf []byte) (*flatbuffers.Builder, error) {
+func HTTPReqCall(buf []byte, conn net.Conn) (*flatbuffers.Builder, error) {
req := inHTTP.CreateRequest(buf)
+ req.BindConn(conn)
defer inHTTP.ReuseRequest(req)
resp := inHTTP.CreateResponse()
diff --git a/internal/plugin/plugin_test.go b/internal/plugin/plugin_test.go
index e5c860d..9016360 100644
--- a/internal/plugin/plugin_test.go
+++ b/internal/plugin/plugin_test.go
@@ -53,7 +53,7 @@
builder.Finish(r)
out := builder.FinishedBytes()
- b, err := HTTPReqCall(out)
+ b, err := HTTPReqCall(out, nil)
assert.Nil(t, err)
out = b.FinishedBytes()
@@ -87,7 +87,7 @@
builder.Finish(r)
out := builder.FinishedBytes()
- b, err := HTTPReqCall(out)
+ b, err := HTTPReqCall(out, nil)
assert.Nil(t, err)
out = b.FinishedBytes()
diff --git a/internal/server/server.go b/internal/server/server.go
index e1adfef..890997f 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -20,7 +20,6 @@
import (
"encoding/binary"
"fmt"
- "io"
"net"
"os"
"os/signal"
@@ -38,44 +37,14 @@
)
const (
- HeaderLen = 4
- MaxDataSize = 2<<24 - 1
-
SockAddrEnv = "APISIX_LISTEN_ADDRESS"
ConfCacheTTLEnv = "APISIX_CONF_EXPIRE_TIME"
)
-const (
- RPCError = iota
- RPCPrepareConf
- RPCHTTPReqCall
- RPCTest = 127 // used only in test
-)
-
var (
dealRPCTest func(buf []byte) (*flatbuffers.Builder, error)
)
-func readErr(n int, err error, required int) bool {
- if 0 < n && n < required {
- err = fmt.Errorf("truncated, only get the first %d bytes", n)
- }
- if err != nil {
- if err != io.EOF {
- log.Errorf("read: %s", err)
- }
- return true
- }
- return false
-}
-
-func writeErr(n int, err error) {
- if err != nil {
- // TODO: solve "write: broken pipe" with context
- log.Errorf("write: %s", err)
- }
-}
-
func generateErrorReport(err error) (ty byte, out []byte) {
if err == ttlcache.ErrNotFound {
log.Warnf("%s", err)
@@ -83,22 +52,28 @@
log.Errorf("%s", err)
}
- ty = RPCError
+ ty = util.RPCError
bd := ReportError(err)
out = bd.FinishedBytes()
util.PutBuilder(bd)
return
}
-func dispatchRPC(ty byte, in []byte) (byte, []byte) {
+func recoverPanic() {
+ if err := recover(); err != nil {
+ log.Errorf("panic recovered: %s", err)
+ }
+}
+
+func dispatchRPC(ty byte, in []byte, conn net.Conn) (byte, []byte) {
var err error
var bd *flatbuffers.Builder
switch ty {
- case RPCPrepareConf:
+ case util.RPCPrepareConf:
bd, err = plugin.PrepareConf(in)
- case RPCHTTPReqCall:
- bd, err = plugin.HTTPReqCall(in)
- case RPCTest: // Just for test
+ case util.RPCHTTPReqCall:
+ bd, err = plugin.HTTPReqCall(in, conn)
+ case util.RPCTest: // Just for test
bd, err = dealRPCTest(in)
default:
err = UnknownType{ty}
@@ -111,8 +86,8 @@
out = bd.FinishedBytes()
util.PutBuilder(bd)
size := len(out)
- if size > MaxDataSize {
- err = fmt.Errorf("the max length of data is %d but got %d", MaxDataSize, size)
+ if size > util.MaxDataSize {
+ err = fmt.Errorf("the max length of data is %d but got %d", util.MaxDataSize, size)
ty, out = generateErrorReport(err)
}
}
@@ -121,19 +96,15 @@
}
func handleConn(c net.Conn) {
- defer func() {
- if err := recover(); err != nil {
- log.Errorf("panic recovered: %s", err)
- }
- }()
+ defer recoverPanic()
log.Infof("Client connected (%s)", c.RemoteAddr().Network())
defer c.Close()
- header := make([]byte, HeaderLen)
+ header := make([]byte, util.HeaderLen)
for {
n, err := c.Read(header)
- if readErr(n, err, HeaderLen) {
+ if util.ReadErr(n, err, util.HeaderLen) {
break
}
@@ -147,11 +118,11 @@
buf := make([]byte, length)
n, err = c.Read(buf)
- if readErr(n, err, int(length)) {
+ if util.ReadErr(n, err, int(length)) {
break
}
- ty, out := dispatchRPC(ty, buf)
+ ty, out := dispatchRPC(ty, buf, c)
size := len(out)
binary.BigEndian.PutUint32(header, uint32(size))
@@ -159,13 +130,13 @@
n, err = c.Write(header)
if err != nil {
- writeErr(n, err)
+ util.WriteErr(n, err)
break
}
n, err = c.Write(out)
if err != nil {
- writeErr(n, err)
+ util.WriteErr(n, err)
break
}
}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index ed55fe1..74bf35c 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -53,14 +53,14 @@
}
func TestDispatchRPC_UnknownType(t *testing.T) {
- ty, _ := dispatchRPC(126, []byte(""))
- assert.Equal(t, byte(RPCError), ty)
+ ty, _ := dispatchRPC(126, []byte(""), nil)
+ assert.Equal(t, byte(util.RPCError), ty)
}
func TestDispatchRPC_OutTooLarge(t *testing.T) {
dealRPCTest = func(buf []byte) (*flatbuffers.Builder, error) {
builder := util.GetBuilder()
- bodyVec := builder.CreateByteVector(make([]byte, MaxDataSize+1))
+ bodyVec := builder.CreateByteVector(make([]byte, util.MaxDataSize+1))
hrc.StopStart(builder)
hrc.StopAddBody(builder, bodyVec)
stop := hrc.StopEnd(builder)
@@ -73,8 +73,8 @@
builder.Finish(res)
return builder, nil
}
- ty, _ := dispatchRPC(RPCTest, []byte(""))
- assert.Equal(t, byte(RPCError), ty)
+ ty, _ := dispatchRPC(util.RPCTest, []byte(""), nil)
+ assert.Equal(t, byte(util.RPCError), ty)
}
func TestRun(t *testing.T) {
diff --git a/internal/util/msg.go b/internal/util/msg.go
new file mode 100644
index 0000000..6645410
--- /dev/null
+++ b/internal/util/msg.go
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "fmt"
+ "io"
+
+ flatbuffers "github.com/google/flatbuffers/go"
+
+ "github.com/apache/apisix-go-plugin-runner/pkg/log"
+)
+
+const (
+ HeaderLen = 4
+ MaxDataSize = 2<<24 - 1
+)
+
+const (
+ RPCError = iota
+ RPCPrepareConf
+ RPCHTTPReqCall
+ RPCExtraInfo
+ RPCTest = 127 // used only in test
+)
+
+type RPCResult struct {
+ Err error
+ Builder *flatbuffers.Builder
+}
+
+// Use struct if the result is not only []byte
+type ExtraInfoResult []byte
+
+func ReadErr(n int, err error, required int) bool {
+ if 0 < n && n < required {
+ err = fmt.Errorf("truncated, only get the first %d bytes", n)
+ }
+ if err != nil {
+ if err != io.EOF {
+ log.Errorf("read: %s", err)
+ }
+ return true
+ }
+ return false
+}
+
+func WriteErr(n int, err error) {
+ if err != nil {
+ // TODO: solve "write: broken pipe" with context
+ log.Errorf("write: %s", err)
+ }
+}
diff --git a/pkg/common/error.go b/pkg/common/error.go
new file mode 100644
index 0000000..5d50332
--- /dev/null
+++ b/pkg/common/error.go
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package common
+
+import "errors"
+
+var (
+ ErrConnClosed = errors.New("The connection is closed")
+)
diff --git a/pkg/http/http.go b/pkg/http/http.go
index 96f8775..92239eb 100644
--- a/pkg/http/http.go
+++ b/pkg/http/http.go
@@ -51,6 +51,13 @@
Header() Header
// Args returns the query string
Args() url.Values
+
+ // Var returns the value of a Nginx variable, like `r.Var("request_time")`
+ //
+ // To fetch the value, the runner will look up the request's cache first. If not found,
+ // the runner will ask it from the APISIX. If the RPC call is failed, an error in
+ // pkg/common.ErrConnClosed type is returned.
+ Var(name string) ([]byte, error)
}
// Header is like http.Header, but only implements the subset of its methods