blob: a1bff1623b7bb839e02810a40e0cd9854f207413 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"errors"
"fmt"
"net"
"net/http"
"sync"
hreqc "github.com/api7/ext-plugin-proto/go/A6/HTTPReqCall"
hrespc "github.com/api7/ext-plugin-proto/go/A6/HTTPRespCall"
flatbuffers "github.com/google/flatbuffers/go"
inHTTP "github.com/apache/apisix-go-plugin-runner/internal/http"
"github.com/apache/apisix-go-plugin-runner/internal/util"
pkgHTTP "github.com/apache/apisix-go-plugin-runner/pkg/http"
"github.com/apache/apisix-go-plugin-runner/pkg/log"
)
type (
	// ParseConfFunc converts a plugin's raw configuration bytes into a
	// configuration value, or returns an error explaining why it cannot.
	ParseConfFunc func(in []byte) (conf interface{}, err error)

	// RequestFilterFunc is a plugin's request-phase hook. It receives the
	// plugin's configuration value, a writer for producing a response, and
	// the incoming request. If it changes the response, the remaining plugins
	// in the chain are not run.
	RequestFilterFunc func(conf interface{}, w http.ResponseWriter, r pkgHTTP.Request)

	// ResponseFilterFunc is a plugin's response-phase hook. It receives the
	// plugin's configuration value and the upstream response. If it changes
	// the response, the remaining plugins in the chain are not run.
	ResponseFilterFunc func(conf interface{}, w pkgHTTP.Response)
)
// pluginOpts bundles the callbacks supplied to RegisterPlugin for one plugin.
type pluginOpts struct {
	// ParseConf turns raw configuration bytes into a configuration value.
	ParseConf ParseConfFunc

	// RequestFilter is invoked by requestPhase.filter for each request.
	RequestFilter RequestFilterFunc

	// ResponseFilter is invoked by responsePhase.filter for each response.
	ResponseFilter ResponseFilterFunc
}
// pluginRegistries maps plugin names to their registered callbacks.
// RegisterPlugin holds the embedded mutex while reading and writing opts.
type pluginRegistries struct {
	sync.Mutex // guards opts

	opts map[string]*pluginOpts // plugin name -> callbacks
}
// ErrPluginRegistered is returned by RegisterPlugin when a plugin with the
// requested name is already present in the registry.
type ErrPluginRegistered struct {
	name string
}

// Compile-time check that ErrPluginRegistered satisfies the error interface.
var _ error = ErrPluginRegistered{}

// Error names the plugin whose registration was rejected.
func (e ErrPluginRegistered) Error() string {
	return fmt.Sprintf("plugin %s registered", e.name)
}
var (
	// pluginRegistry holds every plugin added through RegisterPlugin.
	pluginRegistry = pluginRegistries{opts: make(map[string]*pluginOpts)}

	// Validation errors returned by RegisterPlugin for missing arguments.
	ErrMissingName                 = errors.New("missing name")
	ErrMissingParseConfMethod      = errors.New("missing ParseConf method")
	ErrMissingRequestFilterMethod  = errors.New("missing RequestFilter method")
	ErrMissingResponseFilterMethod = errors.New("missing ResponseFilter method")

	// RequestPhase runs the plugin chain for HTTPReqCall; ResponsePhase runs
	// it for HTTPRespCall.
	RequestPhase  = requestPhase{}
	ResponsePhase = responsePhase{}
)
// RegisterPlugin adds a plugin to the registry under name.
//
// All three callbacks are required. The function returns ErrMissingName,
// ErrMissingParseConfMethod, ErrMissingRequestFilterMethod or
// ErrMissingResponseFilterMethod for the first missing argument. It returns
// ErrPluginRegistered if name is already taken. The duplicate check and the
// insertion happen while holding the registry mutex.
func RegisterPlugin(name string, pc ParseConfFunc, sv RequestFilterFunc, rsv ResponseFilterFunc) error {
	log.Infof("register plugin %s", name)

	switch {
	case name == "":
		return ErrMissingName
	case pc == nil:
		return ErrMissingParseConfMethod
	case sv == nil:
		return ErrMissingRequestFilterMethod
	case rsv == nil:
		return ErrMissingResponseFilterMethod
	}

	pluginRegistry.Lock()
	defer pluginRegistry.Unlock()

	if _, exists := pluginRegistry.opts[name]; exists {
		return ErrPluginRegistered{name: name}
	}

	pluginRegistry.opts[name] = &pluginOpts{
		ParseConf:      pc,
		RequestFilter:  sv,
		ResponseFilter: rsv,
	}
	return nil
}
// findPlugin returns the callbacks registered under name, or nil if no
// plugin with that name has been registered.
//
// RegisterPlugin writes pluginRegistry.opts while holding the registry mutex.
// The lookup here must take the same mutex: an unsynchronized map read that
// races with that write is a data race, and the Go runtime may abort the
// process with "concurrent map read and map write".
func findPlugin(name string) *pluginOpts {
	pluginRegistry.Lock()
	defer pluginRegistry.Unlock()

	// A missing key yields the zero value for *pluginOpts, which is nil.
	return pluginRegistry.opts[name]
}
// requestPhase groups the request-phase steps of HTTPReqCall: running the
// plugin request filters and building the reply. It holds no fields.
type requestPhase struct{}
// filter runs the RequestFilter of each plugin named in conf, in order.
// Names that are not registered are logged and skipped. Once a filter
// changes w, the chain stops so the generated response is sent as-is.
// The returned error is currently always nil.
func (ph *requestPhase) filter(conf RuleConf, w *inHTTP.ReqResponse, r *inHTTP.Request) error {
	for _, item := range conf {
		name := item.Name

		p := findPlugin(name)
		if p == nil {
			log.Warnf("can't find plugin %s, skip", name)
			continue
		}

		log.Infof("run plugin %s", name)
		p.RequestFilter(item.Value, w, r)

		// A response has been generated; later plugins must not run.
		if w.HasChange() {
			return nil
		}
	}
	return nil
}
// builder produces the reply for request call id.
//
// It first asks resp, then req, to write their changes into the builder.
// req is consulted only if resp reported nothing. If neither reports
// changes, the builder receives a bare hreqc Resp containing only id.
func (ph *requestPhase) builder(id uint32, resp *inHTTP.ReqResponse, req *inHTTP.Request) *flatbuffers.Builder {
	b := util.GetBuilder()

	respChanged := resp != nil && resp.FetchChanges(id, b)
	if respChanged || (req != nil && req.FetchChanges(id, b)) {
		return b
	}

	hreqc.RespStart(b)
	hreqc.RespAddId(b, id)
	b.Finish(hreqc.RespEnd(b))
	return b
}
// HTTPReqCall handles one HTTPReqCall message.
//
// It decodes buf into a request bound to conn and looks up the rule
// configuration for the request's conf token. It then runs the request-phase
// plugins and returns the builder holding the reply. The request and response
// objects are released via ReuseRequest/ReuseReqResponse when it returns.
func HTTPReqCall(buf []byte, conn net.Conn) (*flatbuffers.Builder, error) {
	req := inHTTP.CreateRequest(buf)
	req.BindConn(conn)
	defer inHTTP.ReuseRequest(req)

	resp := inHTTP.CreateReqResponse()
	defer inHTTP.ReuseReqResponse(resp)

	conf, err := GetRuleConf(req.ConfToken())
	if err != nil {
		return nil, err
	}

	if err := RequestPhase.filter(conf, resp, req); err != nil {
		return nil, err
	}

	return RequestPhase.builder(req.ID(), resp, req), nil
}
// responsePhase groups the response-phase steps of HTTPRespCall: running the
// plugin response filters and building the reply. It holds no fields.
type responsePhase struct{}
// filter runs the ResponseFilter of each plugin named in conf, in order.
// Names that are not registered are logged and skipped. Once a filter
// changes w, the chain stops so the generated response is sent as-is.
// The returned error is currently always nil.
func (ph *responsePhase) filter(conf RuleConf, w *inHTTP.Response) error {
	for _, item := range conf {
		name := item.Name

		p := findPlugin(name)
		if p == nil {
			log.Warnf("can't find plugin %s, skip", name)
			continue
		}

		log.Infof("run plugin %s", name)
		p.ResponseFilter(item.Value, w)

		// A response has been generated; later plugins must not run.
		if w.HasChange() {
			return nil
		}
	}
	return nil
}
// builder produces the reply for response call id. If resp has changes, they
// are written into the builder. Otherwise the builder receives a bare hrespc
// Resp containing only id.
func (ph *responsePhase) builder(id uint32, resp *inHTTP.Response) *flatbuffers.Builder {
	b := util.GetBuilder()

	if resp == nil || !resp.FetchChanges(b) {
		hrespc.RespStart(b)
		hrespc.RespAddId(b, id)
		b.Finish(hrespc.RespEnd(b))
	}
	return b
}
// HTTPRespCall handles one HTTPRespCall message.
//
// It decodes buf into a response bound to conn and looks up the rule
// configuration for the response's conf token. It then runs the
// response-phase plugins and returns the builder holding the reply. The
// response object is released via ReuseResponse when it returns.
func HTTPRespCall(buf []byte, conn net.Conn) (*flatbuffers.Builder, error) {
	resp := inHTTP.CreateResponse(buf)
	resp.BindConn(conn)
	defer inHTTP.ReuseResponse(resp)

	conf, err := GetRuleConf(resp.ConfToken())
	if err != nil {
		return nil, err
	}

	if err := ResponsePhase.filter(conf, resp); err != nil {
		return nil, err
	}

	return ResponsePhase.builder(resp.ID(), resp), nil
}