blob: e34a8527c9d74aea4fa87dd187d1a0eba063513b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package data_loader
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"path"
"reflect"
"regexp"
"strings"
"github.com/getkin/kin-openapi/openapi3"
"github.com/gin-gonic/gin"
"github.com/shiningrush/droplet"
"github.com/shiningrush/droplet/data"
"github.com/shiningrush/droplet/wrapper"
wgin "github.com/shiningrush/droplet/wrapper/gin"
"github.com/apisix/manager-api/internal/conf"
"github.com/apisix/manager-api/internal/core/entity"
"github.com/apisix/manager-api/internal/core/store"
"github.com/apisix/manager-api/internal/handler"
"github.com/apisix/manager-api/internal/log"
"github.com/apisix/manager-api/internal/utils"
"github.com/apisix/manager-api/internal/utils/consts"
)
// ImportHandler serves the OpenAPI route-import endpoint. It needs the route
// store (duplicate checks and writes) plus the service and upstream stores so
// that IDs referenced by imported routes can be verified.
type ImportHandler struct {
	routeStore    *store.GenericStore // imported routes are checked and written here
	svcStore      store.Interface     // looked up for route.ServiceID
	upstreamStore store.Interface     // looked up for route.UpstreamID
}

// NewImportHandler returns an ImportHandler bound to the stores registered
// under store.HubKeyRoute, store.HubKeyService and store.HubKeyUpstream.
// It never returns an error.
func NewImportHandler() (handler.RouteRegister, error) {
	routes := store.GetStore(store.HubKeyRoute)
	services := store.GetStore(store.HubKeyService)
	upstreams := store.GetStore(store.HubKeyUpstream)
	return &ImportHandler{
		routeStore:    routes,
		svcStore:      services,
		upstreamStore: upstreams,
	}, nil
}
// Patterns applied to OpenAPI path keys when they are turned into route URIs.
var (
	// regPathVar matches a templated path segment such as "{id}" or
	// "{user.name}"; the route URI is cut at the first match and "*" appended.
	regPathVar = regexp.MustCompile(`{[\w.]*}`)
	// regPathRepeat matches a "-APISIX-REPEAT-URI-<digits>" marker, which is
	// stripped from path keys before conversion. Presumably the marker lets one
	// document carry several entries for the same URI — confirm with the exporter.
	regPathRepeat = regexp.MustCompile(`-APISIX-REPEAT-URI-[\d]*`)
)
// ApplyRoute mounts POST /apisix/admin/import/routes on r. The droplet
// wrapper is told to bind each request into an ImportInput (see its auto_read
// tags) before handing it to Import.
func (h *ImportHandler) ApplyRoute(r *gin.Engine) {
	bindInput := wrapper.InputType(reflect.TypeOf(ImportInput{}))
	r.POST("/apisix/admin/import/routes", wgin.Wraps(h.Import, bindInput))
}
// ImportInput is the payload of POST /apisix/admin/import/routes; droplet
// fills it from the request according to the auto_read tags.
type ImportInput struct {
	// Force is read from the "force" query parameter. When true, conflicts with
	// existing routes are tolerated and routes carrying an ID are written with
	// Update instead of Create.
	Force bool `auto_read:"force,query"`
	// FileName is presumably the uploaded file's name (tag "_file"); only its
	// extension is inspected.
	FileName string `auto_read:"_file"`
	// FileContent is the raw uploaded OpenAPI document (tag "file").
	FileContent []byte `auto_read:"file"`
}

// Import converts an uploaded OpenAPI 3 document into routes and stores them.
//
// The upload must have a .json, .yaml or .yml extension and be at most
// conf.ImportSizeLimit bytes. All generated routes are validated (conflict
// check, referenced service/upstream existence, store CreateCheck) before any
// of them is written. On success it returns a map with the number of OpenAPI
// paths read ("paths") and routes written ("routes").
func (h *ImportHandler) Import(c droplet.Context) (interface{}, error) {
	input := c.Input().(*ImportInput)
	force := input.Force

	// file check
	suffix := path.Ext(input.FileName)
	if suffix != ".json" && suffix != ".yaml" && suffix != ".yml" {
		return nil, fmt.Errorf("required file type is .yaml, .yml or .json but got: %s", suffix)
	}
	// len() is the size in bytes. bytes.Count(content, nil)-1 would count UTF-8
	// code points instead, under-reporting multi-byte content and letting
	// oversized uploads through.
	contentLen := len(input.FileContent)
	if contentLen > conf.ImportSizeLimit {
		log.Warnf("upload file size exceeds limit: %d", contentLen)
		return nil, fmt.Errorf("the file size exceeds the limit; limit %d", conf.ImportSizeLimit)
	}
	// A blank upload cannot declare any paths; reject it up front with the same
	// response as the empty-paths check below.
	if len(bytes.TrimSpace(input.FileContent)) == 0 {
		return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, consts.ErrImportFile
	}

	swagger, err := openapi3.NewSwaggerLoader().LoadSwaggerFromData(input.FileContent)
	if err != nil {
		return nil, err
	}
	if len(swagger.Paths) < 1 {
		return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest},
			consts.ErrImportFile
	}

	routes, err := OpenAPI3ToRoute(swagger)
	if err != nil {
		return nil, err
	}

	// Validate every route before creating any of them.
	for _, route := range routes {
		if resp, err := h.validateImportRoute(c.Context(), route, force); err != nil {
			return resp, err
		}
	}

	for _, route := range routes {
		if force && route.ID != nil {
			if _, err := h.routeStore.Update(c.Context(), route, true); err != nil {
				return handler.SpecCodeResponse(err),
					fmt.Errorf("update route(uris:%v) failed: %w", route.Uris, err)
			}
			continue
		}
		if _, err := h.routeStore.Create(c.Context(), route); err != nil {
			return handler.SpecCodeResponse(err),
				fmt.Errorf("create route(uris:%v) failed: %w", route.Uris, err)
		}
	}

	return map[string]int{
		"paths":  len(swagger.Paths),
		"routes": len(routes),
	}, nil
}

// validateImportRoute runs the pre-write checks for one imported route. On
// failure it returns the response to send together with the error.
func (h *ImportHandler) validateImportRoute(ctx context.Context, route *entity.Route, force bool) (interface{}, error) {
	if err := checkRouteExist(ctx, h.routeStore, route); err != nil && !force {
		log.Warnf("import duplicate: %s, route: %#v", err, route)
		return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest},
			fmt.Errorf("route(uris:%v) conflict, %w", route.Uris, err)
	}
	if err := h.checkReferencedID(ctx, h.svcStore, "service", route.ServiceID); err != nil {
		return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
	}
	if err := h.checkReferencedID(ctx, h.upstreamStore, "upstream", route.UpstreamID); err != nil {
		return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
	}
	if _, err := h.routeStore.CreateCheck(route); err != nil {
		return handler.SpecCodeResponse(err),
			fmt.Errorf("create route(uris:%v) failed: %w", route.Uris, err)
	}
	return nil, nil
}

// checkReferencedID verifies that id, when set, exists in s. A missing entry is
// reported with consts.IDNotFound for the given kind; any other lookup error
// is returned unchanged.
func (h *ImportHandler) checkReferencedID(ctx context.Context, s store.Interface, kind string, id interface{}) error {
	if id == nil {
		return nil
	}
	_, err := s.Get(ctx, utils.InterfaceToString(id))
	if err == data.ErrNotFound {
		return fmt.Errorf(consts.IDNotFound, kind, id)
	}
	return err
}
// checkRouteExist returns an error if routeStore already holds a route whose
// Host, URI, Uris, RemoteAddrs, RemoteAddr, Hosts, Priority, Vars and
// FilterFunc all equal route's. When route carries an ID, only the stored
// route with that same ID is considered. Errors from List are returned as-is;
// a match yields consts.InvalidParam("route is duplicate").
func checkRouteExist(ctx context.Context, routeStore *store.GenericStore, route *entity.Route) error {
	wantID := utils.InterfaceToString(route.ID)
	sameMatchRules := func(obj interface{}) bool {
		stored := obj.(*entity.Route)
		if wantID != "" && utils.InterfaceToString(stored.ID) != wantID {
			return false
		}
		return stored.Host == route.Host &&
			stored.URI == route.URI &&
			utils.StringSliceEqual(stored.Uris, route.Uris) &&
			utils.StringSliceEqual(stored.RemoteAddrs, route.RemoteAddrs) &&
			stored.RemoteAddr == route.RemoteAddr &&
			utils.StringSliceEqual(stored.Hosts, route.Hosts) &&
			stored.Priority == route.Priority &&
			utils.ValueEqual(stored.Vars, route.Vars) &&
			stored.FilterFunc == route.FilterFunc
	}

	ret, err := routeStore.List(ctx, store.ListInput{
		Predicate:  sameMatchRules,
		PageSize:   0,
		PageNumber: 0,
	})
	if err != nil {
		return err
	}
	if len(ret.Rows) == 0 {
		return nil
	}
	return consts.InvalidParam("route is duplicate")
}
func parseExtension(val *openapi3.Operation) (*entity.Route, error) {
routeMap := map[string]interface{}{}
for key, val := range val.Extensions {
if strings.HasPrefix(key, "x-apisix-") {
routeMap[strings.TrimPrefix(key, "x-apisix-")] = val
}
}
route := new(entity.Route)
routeJson, err := json.Marshal(routeMap)
if err != nil {
return nil, err
}
err = json.Unmarshal(routeJson, &route)
if err != nil {
return nil, err
}
return route, nil
}
// PathValue pairs an HTTP method with the OpenAPI operation declared for it
// on a single path.
type PathValue struct {
	Method string
	Value  *openapi3.Operation
}

// mergePathValue converts the operations declared on one OpenAPI path (key)
// into routes. Operations that are identical apart from their method are
// folded into a single route that lists all of their methods.
//
// The result is keyed by the method of the first operation in each group.
// Side effect: each operation's OperationID has its method name removed in
// place (first occurrence only).
func mergePathValue(key string, values []PathValue, swagger *openapi3.Swagger) (map[string]*entity.Route, error) {
	// Group identical operations BEFORE converting any of them.
	// getRouteFromPaths rewrites the operation it is given (clears schema
	// formats, inlines $ref'd body properties). Comparing a later operation
	// against an already-converted one would therefore miss real duplicates
	// and split one route into several.
	type group struct {
		first   PathValue
		methods []string // further methods merged into first, in input order
	}
	var groups []*group
	for _, value := range values {
		value.Value.OperationID = strings.Replace(value.Value.OperationID, value.Method, "", 1)
		var match *group
		for _, g := range groups {
			if utils.ValueEqual(g.first.Value, value.Value) {
				match = g
				break
			}
		}
		if match != nil {
			match.methods = append(match.methods, value.Method)
			continue
		}
		groups = append(groups, &group{first: value})
	}

	routes := make(map[string]*entity.Route, len(groups))
	for _, g := range groups {
		route, err := getRouteFromPaths(g.first.Method, key, g.first.Value, swagger)
		if err != nil {
			return nil, err
		}
		route.Methods = append(route.Methods, g.methods...)
		routes[g.first.Method] = route
	}
	return routes, nil
}
// OpenAPI3ToRoute converts the paths of an OpenAPI 3 document into routes.
//
// For each path, any "-APISIX-REPEAT-URI-<n>" marker is removed from the key.
// Every operation declared on the path is then converted: GET, POST, HEAD,
// PUT, PATCH, DELETE, OPTIONS, TRACE and CONNECT. Identical operations on the
// same path are merged into one multi-method route by mergePathValue.
//
// Routes of one path are emitted in the method order above. Paths themselves
// are visited in map order.
func OpenAPI3ToRoute(swagger *openapi3.Swagger) ([]*entity.Route, error) {
	var routes []*entity.Route
	for rawKey, pathItem := range swagger.Paths {
		if pathItem == nil {
			continue
		}
		key := regPathRepeat.ReplaceAllString(rawKey, "")

		if up, ok := pathItem.Extensions["x-apisix-upstream"]; ok {
			raw, isRaw := up.(json.RawMessage)
			if !isRaw {
				return nil, fmt.Errorf("path %s: unexpected x-apisix-upstream value of type %T", rawKey, up)
			}
			// The path-level upstream is only validated here; the decoded value
			// is not attached to any generated route.
			// NOTE(review): confirm whether it should be applied to routes that
			// declare no upstream of their own.
			if err := json.Unmarshal(raw, &entity.UpstreamDef{}); err != nil {
				return nil, err
			}
		}

		var values []PathValue
		for _, candidate := range []PathValue{
			{Method: http.MethodGet, Value: pathItem.Get},
			{Method: http.MethodPost, Value: pathItem.Post},
			{Method: http.MethodHead, Value: pathItem.Head},
			{Method: http.MethodPut, Value: pathItem.Put},
			{Method: http.MethodPatch, Value: pathItem.Patch},
			{Method: http.MethodDelete, Value: pathItem.Delete},
			{Method: http.MethodOptions, Value: pathItem.Options},
			{Method: http.MethodTrace, Value: pathItem.Trace},
			{Method: http.MethodConnect, Value: pathItem.Connect},
		} {
			if candidate.Value != nil {
				values = append(values, candidate)
			}
		}

		// merge same route
		merged, err := mergePathValue(key, values, swagger)
		if err != nil {
			return nil, err
		}
		// Walk values rather than the map so the output order is stable.
		for _, value := range values {
			if route, ok := merged[value.Method]; ok {
				routes = append(routes, route)
			}
		}
	}
	return routes, nil
}
// parseParameters sets the "header_schema" entry of the request-validation
// plugin in plugins, based on an operation's parameters. Any existing
// request-validation config is kept and extended.
//
// Every parameter schema has "format" and "xml" cleared. Presumably this is
// because those OpenAPI keywords are not meaningful to the validator — confirm.
// Only header parameters contribute: their schemas become the properties of
// an object schema, and required headers are listed in "required".
// Parameters whose reference or schema did not resolve (nil Value) are
// skipped rather than dereferenced.
func parseParameters(parameters openapi3.Parameters, plugins map[string]interface{}) {
	props := make(map[string]interface{})
	var required []string
	for _, param := range parameters {
		if param == nil || param.Value == nil {
			continue
		}
		var schema *openapi3.Schema
		if param.Value.Schema != nil {
			schema = param.Value.Schema.Value
		}
		if schema != nil {
			schema.Format = ""
			schema.XML = nil
		}
		if param.Value.In != "header" {
			continue
		}
		if schema != nil {
			props[param.Value.Name] = schema
		}
		if param.Value.Required {
			required = append(required, param.Value.Name)
		}
	}

	requestValidation := make(map[string]interface{})
	if rv, ok := plugins["request-validation"]; ok {
		requestValidation = rv.(map[string]interface{})
	}
	requestValidation["header_schema"] = &entity.RequestValidation{
		Type:       "object",
		Required:   required,
		Properties: props,
	}
	plugins["request-validation"] = requestValidation
}
// parseRequestBody sets the "body_schema" entry of the request-validation
// plugin from an operation's request body. It merges into any
// request-validation config already present in plugins, such as the
// header_schema written by parseParameters.
//
// For each media type in the body's content:
//   - a $ref schema is resolved via getParameters and its type, required
//     list and properties are used;
//   - an inline schema with properties has its $ref'd properties resolved and
//     each property's "format" cleared, then is used as-is;
//   - an inline array whose items are a $ref uses the referenced schema;
//   - any other inline schema becomes {type: object, required: []}.
//
// Each media type overwrites body_schema, so with several media types the
// last one visited (Go map order, unspecified) wins. Media types without a
// schema, and references that cannot be resolved, are skipped instead of
// causing a nil-pointer panic.
func parseRequestBody(requestBody *openapi3.RequestBodyRef, swagger *openapi3.Swagger, plugins map[string]interface{}) {
	if requestBody == nil || requestBody.Value == nil {
		return
	}
	requestValidation := make(map[string]interface{})
	if rv, ok := plugins["request-validation"]; ok {
		requestValidation = rv.(map[string]interface{})
	}

	// fromSchema copies the fields request-validation needs out of a resolved
	// schema, or returns nil when the reference did not resolve.
	fromSchema := func(ref *openapi3.SchemaRef) *entity.RequestValidation {
		if ref == nil || ref.Value == nil {
			return nil
		}
		return &entity.RequestValidation{
			Type:       ref.Value.Type,
			Required:   ref.Value.Required,
			Properties: ref.Value.Properties,
		}
	}

	examined := false
	for _, mediaType := range requestBody.Value.Content {
		if mediaType == nil || mediaType.Schema == nil {
			// OpenAPI allows a media type without a schema: nothing to validate.
			continue
		}
		examined = true
		schema := mediaType.Schema

		var bodySchema *entity.RequestValidation
		switch {
		case schema.Ref != "":
			bodySchema = fromSchema(getParameters(schema.Ref, &swagger.Components))
		case schema.Value == nil:
			// Unresolved inline schema: leave body_schema untouched.
		case schema.Value.Properties != nil:
			for name, prop := range schema.Value.Properties {
				if prop == nil {
					continue
				}
				if prop.Ref != "" {
					schema.Value.Properties[name] = getParameters(prop.Ref, &swagger.Components)
				}
				if prop.Value != nil {
					prop.Value.Format = ""
				}
			}
			bodySchema = fromSchema(schema)
		case schema.Value.Items != nil:
			if schema.Value.Items.Ref != "" {
				bodySchema = fromSchema(getParameters(schema.Value.Items.Ref, &swagger.Components))
			}
		default:
			bodySchema = &entity.RequestValidation{
				Type:       "object",
				Required:   []string{},
				Properties: schema.Value.Properties,
			}
		}
		if bodySchema != nil {
			requestValidation["body_schema"] = bodySchema
		}
	}

	// Write the plugin back only if some media type was examined, so a body
	// without schemas does not introduce an empty request-validation entry.
	if examined {
		plugins["request-validation"] = requestValidation
	}
}
// parseSecurity enables an authentication plugin, with an empty config, for
// each security scheme named in the operation's security requirements:
//
//	HTTP "basic"                    -> basic-auth
//	HTTP "bearer", format "JWT"     -> jwt-auth
//	"apiKey"                        -> key-auth
//
// Names missing from securitySchemes, schemes with a nil Value, and any other
// scheme kinds are ignored.
//
// TODO: import consumers. Credentials (username/password, key/secret) are not
// read from the scheme, so matching consumers presumably have to be
// configured separately.
func parseSecurity(security openapi3.SecurityRequirements, securitySchemes openapi3.SecuritySchemes, plugins map[string]interface{}) {
	for _, requirement := range security {
		for name := range requirement {
			schemeRef, declared := securitySchemes[name]
			if !declared || schemeRef.Value == nil {
				continue
			}
			scheme := schemeRef.Value
			switch {
			case scheme.Type == "http" && scheme.Scheme == "basic":
				plugins["basic-auth"] = map[string]interface{}{}
			case scheme.Type == "http" && scheme.Scheme == "bearer" && scheme.BearerFormat == "JWT":
				plugins["jwt-auth"] = map[string]interface{}{}
			case scheme.Type == "apiKey":
				plugins["key-auth"] = map[string]interface{}{}
			}
		}
	}
}
// getRouteFromPaths builds the route for one operation (method) on OpenAPI
// path key.
//
// The route URI is key up to its first templated segment, followed by "*".
// For example, "/users/{id}/posts" becomes "/users/*"; keys without a template
// are used verbatim. The operation's x-apisix-* extensions seed the route.
// Name, description and method come from the operation. Parameters, request
// body and security requirements then add plugin configuration.
func getRouteFromPaths(method, key string, value *openapi3.Operation, swagger *openapi3.Swagger) (*entity.Route, error) {
	uri := key
	if loc := regPathVar.FindStringIndex(key); loc != nil {
		uri = key[:loc[0]] + "*"
	}

	route, err := parseExtension(value)
	if err != nil {
		return nil, err
	}
	route.Uris = []string{uri}
	route.Name = value.OperationID
	route.Desc = value.Summary
	route.Methods = []string{method}
	if route.Plugins == nil {
		route.Plugins = map[string]interface{}{}
	}

	// Order matters: parseRequestBody extends the request-validation entry
	// that parseParameters creates.
	if value.Parameters != nil {
		parseParameters(value.Parameters, route.Plugins)
	}
	if value.RequestBody != nil {
		parseRequestBody(value.RequestBody, swagger, route.Plugins)
	}
	if value.Security != nil && swagger.Components.SecuritySchemes != nil {
		parseSecurity(*value.Security, swagger.Components.SecuritySchemes, route.Plugins)
	}
	return route, nil
}
// getParameters resolves a local schema reference of the form
// "#/components/schemas/<name>". It returns the referenced schema with "xml"
// cleared and its direct properties expanded:
//   - $ref'd properties are replaced by their (recursively expanded) component
//     schema, with "format" cleared;
//   - inline properties lose "xml" and "format";
//   - array items that are $refs are expanded, and inline items lose
//     "xml" and "format".
//
// Any other kind of reference, or a name absent from components, yields an
// empty *openapi3.SchemaRef (Value == nil); callers must check Value.
//
// The result is a fresh copy, so the shared component schemas are never
// rewired. A reference to a schema that is already being expanded (a
// recursive schema such as a tree node) is left as a $ref. Without that,
// expansion would recurse without bound and overflow the goroutine stack.
func getParameters(ref string, components *openapi3.Components) *openapi3.SchemaRef {
	if expanded := expandComponentSchema(ref, components, map[string]bool{}); expanded != nil {
		return expanded
	}
	return &openapi3.SchemaRef{}
}

// expandComponentSchema returns an expanded copy of the component schema ref
// points at. It returns nil if ref is not "#/components/schemas/<name>", the
// schema is missing or unresolved, or <name> is already in expanding (a cycle).
func expandComponentSchema(ref string, components *openapi3.Components, expanding map[string]bool) *openapi3.SchemaRef {
	parts := strings.Split(ref, "/")
	if components == nil || len(parts) < 4 || parts[0] != "#" || parts[1] != "components" || parts[2] != "schemas" {
		return nil
	}
	name := parts[3]
	component := components.Schemas[name]
	if component == nil || component.Value == nil || expanding[name] {
		return nil
	}
	expanding[name] = true
	defer delete(expanding, name)

	schema := *component.Value // shallow copy; the shared component stays untouched
	schema.XML = nil
	if len(component.Value.Properties) > 0 {
		props := make(map[string]*openapi3.SchemaRef, len(component.Value.Properties))
		for key, prop := range component.Value.Properties {
			props[key] = expandSchemaProperty(prop, components, expanding)
		}
		schema.Properties = props
	}
	return &openapi3.SchemaRef{Value: &schema}
}

// expandSchemaProperty returns the expanded form of one component-schema
// property, as described on getParameters. Unresolvable or recursive
// references are returned unchanged, so they stay a $ref.
func expandSchemaProperty(prop *openapi3.SchemaRef, components *openapi3.Components, expanding map[string]bool) *openapi3.SchemaRef {
	if prop == nil {
		return nil
	}
	if prop.Ref != "" {
		expanded := expandComponentSchema(prop.Ref, components, expanding)
		if expanded == nil {
			return prop
		}
		expanded.Value.Format = ""
		return expanded
	}
	if prop.Value == nil {
		return prop
	}

	value := *prop.Value
	value.XML = nil
	value.Format = ""
	if items := prop.Value.Items; items != nil {
		if items.Ref != "" {
			if expanded := expandComponentSchema(items.Ref, components, expanding); expanded != nil {
				value.Items = expanded
			}
		} else if items.Value != nil {
			itemValue := *items.Value
			itemValue.XML = nil
			itemValue.Format = ""
			value.Items = &openapi3.SchemaRef{Value: &itemValue}
		}
	}
	return &openapi3.SchemaRef{Value: &value}
}