blob: 54dfba07d488ad7b73dfe2f858d24b0ee43ae8c0 [file] [log] [blame]
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local http = require("resty.http")
local plugin = require("apisix.plugin")
local ngx = ngx
local ipairs = ipairs
local pairs = pairs
local str_find = core.string.find
local str_lower = string.lower
-- Identity of this plugin and the default uri its API is served on.
local plugin_name = "batch-requests"
local default_uri = "/apisix/batch-requests"
-- Fallback upper bound on the batch request body when no metadata is set.
local default_max_body_size = 1024 * 1024 -- 1MiB

-- Plugin attribute schema (looked up via plugin.plugin_attr): lets the
-- API uri be overridden.
local attr_schema = {
    type = "object",
    properties = {
        uri = {
            default = default_uri,
            description = "uri for batch-requests",
            type = "string",
        },
    },
}

-- Plugin configuration schema: any object is accepted, no properties are
-- declared.
local schema = {type = "object"}

-- Plugin metadata schema: bounds the accepted request body size.
local metadata_schema = {
    type = "object",
    properties = {
        max_body_size = {
            default = default_max_body_size,
            exclusiveMinimum = 0,
            type = "integer",
            description = "max pipeline body size in bytes",
        },
    },
}
-- Copy the shared core method schema so setting a default here does not
-- mutate it for every other user of core.schema.method_schema.
local method_schema = core.table.clone(core.schema.method_schema)
method_schema.default = "GET"

-- Schema of the JSON body accepted by the batch-requests API.
-- The top-level `query` and `headers` are merged into every pipelined
-- request (by set_common_query / set_common_header), and `timeout` is
-- passed to httpc:set_timeout for the connection carrying the whole batch.
local req_schema = {
    type = "object",
    properties = {
        query = {
            description = "pipeline query string",
            type = "object"
        },
        headers = {
            description = "pipeline header",
            type = "object"
        },
        timeout = {
            description = "pipeline timeout(ms)",
            type = "integer",
            default = 30000,
        },
        pipeline = {
            type = "array",
            minItems = 1,
            items = {
                type = "object",
                properties = {
                    version = {
                        description = "HTTP version",
                        type = "number",
                        enum = {1.0, 1.1},
                        default = 1.1,
                    },
                    method = method_schema,
                    path = {
                        type = "string",
                        minLength = 1,
                    },
                    query = {
                        description = "request query string",
                        type = "object",
                    },
                    headers = {
                        description = "request header",
                        type = "object",
                    },
                    -- NOTE(review): the batch is sent over a plain
                    -- httpc:connect("127.0.0.1", port) connection, so this
                    -- flag may have no effect; confirm before relying on it.
                    ssl_verify = {
                        type = "boolean",
                        default = false
                    },
                }
            }
        }
    },
    anyOf = {
        {required = {"pipeline"}},
    },
}
-- Plugin descriptor exported by this module; the schemas referenced here
-- are the module-level tables defined above.
local _M = {
    name = plugin_name,
    version = 0.1,
    priority = 4010,
    scope = "global",
    schema = schema,
    attr_schema = attr_schema,
    metadata_schema = metadata_schema,
}
-- Validate `conf` against the plugin's metadata schema when `schema_type`
-- is core.schema.TYPE_METADATA, otherwise against the plugin config schema.
-- Returns whatever core.schema.check returns.
function _M.check_schema(conf, schema_type)
    local target_schema = schema
    if schema_type == core.schema.TYPE_METADATA then
        target_schema = metadata_schema
    end
    return core.schema.check(target_schema, conf)
end
-- Validate a decoded batch request body against req_schema.
-- Returns nothing when valid; otherwise returns 400 and an error body.
local function check_input(data)
    local valid, reason = core.schema.check(req_schema, data)
    if valid then
        return
    end
    return 400, {error_msg = "bad request body: " .. reason}
end
-- Return a new table holding the entries of `obj` with every key
-- lower-cased. A nil (or false) `obj` yields a fresh empty table.
-- Keys that collide after lower-casing keep whichever pairs() visits last.
local function lowercase_key_or_init(obj)
    local normalized = {}
    if obj then
        for key, value in pairs(obj) do
            normalized[str_lower(key)] = value
        end
    end
    return normalized
end
-- Replace the batch-level headers and the headers of every pipelined
-- request with lower-cased-key copies, so that later lookups by lower-case
-- name are reliable and every request ends up with a headers table.
local function ensure_header_lowercase(data)
    data.headers = lowercase_key_or_init(data.headers)

    for _, sub_request in ipairs(data.pipeline) do
        sub_request.headers = lowercase_key_or_init(sub_request.headers)
    end
end
-- Fill in headers for every pipelined request. A request's own headers win,
-- then the batch-level `data.headers`, then the headers of the outer batch
-- request except those whose name starts with "content-" (presumably because
-- they describe the outer body, not the sub-request's). Finally the
-- configured real-IP header is always overwritten with the client IP.
--
-- Expects `data.headers` and every `req.headers` to be tables with
-- lower-case keys, as produced by ensure_header_lowercase.
local function set_common_header(data)
    local local_conf = core.config.local_conf()
    local real_ip_hdr = core.table.try_read_attr(local_conf, "nginx_config", "http",
                                                 "real_ip_header")
    -- try_read_attr yields nil when the option is absent, and str_lower(nil)
    -- would raise; fall back to the conventional X-Real-IP header.
    -- we don't need to handle '_' to '-' as Nginx won't treat 'X_REAL_IP' as 'X-Real-IP'
    real_ip_hdr = str_lower(real_ip_hdr or "X-Real-IP")

    -- the client address is the same for every sub-request; look it up once
    local client_ip = core.request.get_remote_client_ip()
    local outer_headers = core.request.headers(nil)

    for _, req in ipairs(data.pipeline) do
        local req_headers = req.headers

        for k, v in pairs(data.headers) do
            if not req_headers[k] then
                req_headers[k] = v
            end
        end

        if outer_headers then
            for k, v in pairs(outer_headers) do
                -- skip header start with "content-"
                local is_content_header = str_find(k, "content-") == 1
                if not req_headers[k] and not is_content_header then
                    req_headers[k] = v
                end
            end
        end

        req_headers[real_ip_hdr] = client_ip
    end
end
-- Copy into `dst` every entry of `src` whose key is missing (or falsy)
-- in `dst`; entries already present in `dst` are left untouched.
local function fill_missing_query_args(dst, src)
    for name, value in pairs(src) do
        if not dst[name] then
            dst[name] = value
        end
    end
end

-- Apply the batch-level query arguments to every pipelined request.
-- A request with no `query` of its own receives the batch-level table
-- itself (shared, not copied); otherwise only the arguments it lacks are
-- added to its own table. Does nothing when `data.query` is absent.
local function set_common_query(data)
    local common_query = data.query
    if not common_query then
        return
    end

    for _, sub_request in ipairs(data.pipeline) do
        local own_query = sub_request.query
        if own_query then
            fill_missing_query_args(own_query, common_query)
        else
            sub_request.query = common_query
        end
    end
end
-- Resolve the size limit for the batch request body: the `max_body_size`
-- from this plugin's metadata when configured, else default_max_body_size.
local function get_max_body_size()
    local metadata = plugin.plugin_metadata(plugin_name)
    core.log.info("metadata: ", core.json.delay_encode(metadata))
    local conf_size = metadata and metadata.value and metadata.value.max_body_size
    return conf_size or default_max_body_size
end

-- Turn the pipelined responses into plain tables, producing exactly one
-- entry per pipelined request so that entry i always describes request i.
-- A response without a status (reading it failed, e.g. on timeout) is
-- reported as 504 "upstream timeout".
local function aggregate_responses(responses)
    local aggregated_resp = {}
    for i, resp in ipairs(responses) do
        local sub_resp
        if not resp.status then
            sub_resp = {
                status = 504,
                reason = "upstream timeout"
            }
        else
            sub_resp = {
                status = resp.status,
                reason = resp.reason,
                headers = resp.headers,
            }
            if resp.has_body then
                local body, read_err = resp:read_body()
                if read_err then
                    core.log.error("failed to read body of pipeline response ",
                                   i, ": ", read_err)
                end
                sub_resp.body = body
            end
        end
        aggregated_resp[i] = sub_resp
    end
    return aggregated_resp
end

-- Handler of the batch-requests API route.
--
-- Reads a JSON body describing a pipeline of HTTP requests, validates it
-- against req_schema, merges the common headers/query into each request,
-- sends them over one connection to 127.0.0.1:<server_port> and returns
-- 200 with an array whose i-th element describes the i-th response.
--
-- @param ctx the APISIX request context
-- @return HTTP status code and a response body table
local function batch_requests(ctx)
    local req_body, err = core.request.get_body(get_max_body_size(), ctx)
    if err then
        -- Nginx doesn't support 417: https://trac.nginx.org/nginx/ticket/2062
        -- So always return 413 instead
        return 413, { error_msg = err }
    end
    if not req_body then
        return 400, {
            error_msg = "no request body, you should give at least one pipeline setting"
        }
    end

    local data, decode_err = core.json.decode(req_body)
    if not data then
        -- the decoder may return a falsy value without an error message
        -- (e.g. for a literal JSON `false` body), so never concatenate nil
        return 400, {
            error_msg = "invalid request body: " .. req_body .. ", err: "
                        .. (decode_err or "request body must be a JSON object")
        }
    end

    local code, body = check_input(data)
    if code then
        return code, body
    end

    local httpc = http.new()
    httpc:set_timeout(data.timeout)
    local ok, connect_err = httpc:connect("127.0.0.1", ngx.var.server_port)
    if not ok then
        return 500, {error_msg = "connect to apisix failed: " .. connect_err}
    end

    ensure_header_lowercase(data)
    set_common_header(data)
    set_common_query(data)

    local responses, pipeline_err = httpc:request_pipeline(data.pipeline)
    if not responses then
        httpc:close()
        return 400, {error_msg = "request failed: " .. pipeline_err}
    end

    -- pipelined responses are presumably read from the socket lazily, so
    -- release the connection only after every response has been consumed
    local aggregated_resp = aggregate_responses(responses)
    httpc:close()
    return 200, aggregated_resp
end
-- Describe the HTTP API exposed by this plugin: a single POST route whose
-- uri comes from the plugin attribute `uri`, falling back to default_uri.
function _M.api()
    local attr = plugin.plugin_attr(plugin_name)
    local uri = attr and attr.uri or default_uri

    local route = {
        methods = {"POST"},
        uri = uri,
        handler = batch_requests,
    }
    return {route}
end

return _M