blob: e052160dc95763cdabd91a30a585ca5fbe7ff825 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package triple
import (
"errors"
"net/http"
"net/url"
"strings"
"time"
)
import (
"github.com/dubbogo/gost/log/logger"
)
import (
"dubbo.apache.org/dubbo-go/v3/global"
)
// Options holds the Triple protocol configuration assembled from Option values.
//
// TODO: The triple options for the server and client are mixed together now.
// We need to find a way to separate them later.
type Options struct {
	Triple *global.TripleConfig
}

// defaultOptions returns Options whose Triple field is a fresh
// global.DefaultTripleConfig().
func defaultOptions() *Options {
	o := new(Options)
	o.Triple = global.DefaultTripleConfig()
	return o
}

// NewOptions starts from defaultOptions and applies each Option in the order
// given, so a later Option can overwrite a field set by an earlier one.
func NewOptions(opts ...Option) *Options {
	result := defaultOptions()
	for i := range opts {
		opts[i](result)
	}
	return result
}

// Option mutates an Options value; see NewOptions.
type Option func(*Options)
// WithKeepAlive configures both keep-alive settings of the Triple protocol.
//
// interval is the time between keep-alive pings; timeout is how long to wait
// for a keep-alive response before the connection is considered dead. Both
// are stored in their time.Duration.String() form (e.g. "10s").
// If not set, the default interval is 10s and the default timeout is 20s.
func WithKeepAlive(interval, timeout time.Duration) Option {
	intervalText, timeoutText := interval.String(), timeout.String()
	return func(o *Options) {
		cfg := o.Triple
		cfg.KeepAliveInterval = intervalText
		cfg.KeepAliveTimeout = timeoutText
	}
}
// WithKeepAliveInterval configures only the keep-alive ping interval of the
// Triple protocol, stored in its time.Duration.String() form.
// If not set, the default interval is 10s.
func WithKeepAliveInterval(interval time.Duration) Option {
	text := interval.String()
	return func(o *Options) {
		o.Triple.KeepAliveInterval = text
	}
}
// WithKeepAliveTimeout configures only the keep-alive timeout of the Triple
// protocol: how long to wait for a keep-alive response before the connection
// is considered dead. The value is stored in its time.Duration.String() form.
// If not set, the default timeout is 20s.
func WithKeepAliveTimeout(timeout time.Duration) Option {
	text := timeout.String()
	return func(o *Options) {
		o.Triple.KeepAliveTimeout = text
	}
}
// WithMaxServerSendMsgSize limits the size of messages the server may send.
// size is a byte size written as a string, for example "4MB"; it is stored
// as given without validation here.
// If not set, the default is 2147MB (math.MaxInt32).
func WithMaxServerSendMsgSize(size string) Option {
	return func(o *Options) {
		cfg := o.Triple
		cfg.MaxServerSendMsgSize = size
	}
}
// WithMaxServerRecvMsgSize limits the size of messages the server may receive.
// size is a byte size written as a string, for example "4MB"; it is stored
// as given without validation here.
// If not set, the default is 4MB (4194304 bytes).
func WithMaxServerRecvMsgSize(size string) Option {
	return func(o *Options) {
		cfg := o.Triple
		cfg.MaxServerRecvMsgSize = size
	}
}
// WithCORS builds a CORS configuration from global.DefaultCorsConfig() plus
// the given CORSOption values, validates it once, and returns an Option that
// installs it as the Triple CORS configuration.
//
// If validation fails, the error is logged and the returned Option does
// nothing, so an invalid CORS configuration is ignored rather than applied.
func WithCORS(opts ...CORSOption) Option {
	cfg := global.DefaultCorsConfig()
	for i := range opts {
		opts[i](cfg)
	}

	if err := validateCorsConfig(cfg); err != nil {
		logger.Errorf("[TRIPLE] invalid CORS config: %v", err)
		noop := func(*Options) {}
		return noop
	}

	return func(o *Options) {
		o.Triple.Cors = cfg
	}
}
// Http3Enable enables HTTP/3 support for the Triple protocol.
// This option configures the server to start both HTTP/2 and HTTP/3 servers
// simultaneously, providing modern HTTP/3 capabilities alongside traditional HTTP/2.
//
// When enabled, the server will:
// - Start an HTTP/3 server using QUIC protocol
// - Continue running the existing HTTP/2 server
// - Enable protocol negotiation between HTTP/2 and HTTP/3
// - Provide improved performance and security benefits of HTTP/3
//
// Usage Examples:
//
// // Basic HTTP/3 enablement
// server := triple.NewServer(
// triple.Http3Enable(),
// )
//
// Requirements:
// - TLS configuration is required for HTTP/3
// - Server must have valid TLS certificates
// - Clients must support HTTP/3 for full benefits
// - Fallback to HTTP/2 is automatic for unsupported clients
//
// Default Behavior:
// - HTTP/3 is disabled by default for backward compatibility
// - When enabled, negotiation defaults to true
// - Both HTTP/2 and HTTP/3 servers run on the same port
//
// # Experimental
//
// NOTICE: This API is EXPERIMENTAL and may be changed or removed in
// a later release.
func Http3Enable() Option {
	return func(o *Options) {
		// Only the Enable flag is touched; Negotiation keeps whatever value
		// the config already holds.
		triple := o.Triple
		triple.Http3.Enable = true
	}
}
// Http3Negotiation configures HTTP/3 negotiation behavior for the Triple protocol.
// This option controls whether HTTP/2 Alternative Services (Alt-Svc) negotiation
// is enabled when both HTTP/2 and HTTP/3 servers are running simultaneously.
//
// Usage Examples:
//
// // Enable HTTP/3 negotiation (default behavior)
// server := triple.NewServer(
// triple.Http3Enable(),
// triple.Http3Negotiation(true),
// )
//
// // Disable HTTP/3 negotiation for explicit protocol control
// server := triple.NewServer(
// triple.Http3Enable(),
// triple.Http3Negotiation(false),
// )
//
// Default Behavior:
// - When HTTP/3 is enabled, negotiation defaults to true
// - This ensures backward compatibility and optimal client experience
//
// # Experimental
//
// NOTICE: This API is EXPERIMENTAL and may be changed or removed in
// a later release.
func Http3Negotiation(negotiation bool) Option {
	return func(o *Options) {
		// Only the Negotiation flag is touched; Enable is left unchanged.
		triple := o.Triple
		triple.Http3.Negotiation = negotiation
	}
}
// CORSOption mutates one field of a global.CorsConfig; pass CORSOption values
// to WithCORS.
type CORSOption func(*global.CorsConfig)

// CORSAllowOrigins replaces the origins allowed to make CORS requests.
// AllowOrigins receives its own copy of origins rather than aliasing the
// caller's slice; an empty list leaves AllowOrigins nil.
func CORSAllowOrigins(origins ...string) CORSOption {
	return func(cfg *global.CorsConfig) {
		var list []string
		list = append(list, origins...)
		cfg.AllowOrigins = list
	}
}
// CORSAllowMethods replaces the HTTP methods allowed for CORS requests.
// AllowMethods receives its own copy of methods; an empty list leaves it nil.
func CORSAllowMethods(methods ...string) CORSOption {
	return func(cfg *global.CorsConfig) {
		var list []string
		list = append(list, methods...)
		cfg.AllowMethods = list
	}
}
// CORSAllowHeaders replaces the request headers allowed for CORS requests.
// AllowHeaders receives its own copy of headers; an empty list leaves it nil.
func CORSAllowHeaders(headers ...string) CORSOption {
	return func(cfg *global.CorsConfig) {
		var list []string
		list = append(list, headers...)
		cfg.AllowHeaders = list
	}
}
// CORSExposeHeaders replaces the response headers exposed to the browser.
// ExposeHeaders receives its own copy of headers; an empty list leaves it nil.
func CORSExposeHeaders(headers ...string) CORSOption {
	return func(cfg *global.CorsConfig) {
		var list []string
		list = append(list, headers...)
		cfg.ExposeHeaders = list
	}
}
// CORSAllowCredentials sets whether credentials are allowed on CORS requests.
// WithCORS rejects allow=true when the allowed origins contain "*".
func CORSAllowCredentials(allow bool) CORSOption {
	return func(cfg *global.CorsConfig) {
		cfg.AllowCredentials = allow
	}
}
// CORSMaxAge sets how long a preflight response may be cached.
// A negative value is rejected when WithCORS validates the config.
func CORSMaxAge(maxAge int) CORSOption {
	return func(cfg *global.CorsConfig) {
		cfg.MaxAge = maxAge
	}
}
// validHTTPMethods is the set of method names (the upper-case net/http
// Method* constants) that validateCorsConfig accepts in allow-methods.
var validHTTPMethods = func() map[string]bool {
	methods := []string{
		http.MethodGet,
		http.MethodHead,
		http.MethodPost,
		http.MethodPut,
		http.MethodPatch,
		http.MethodDelete,
		http.MethodConnect,
		http.MethodOptions,
		http.MethodTrace,
	}
	set := make(map[string]bool, len(methods))
	for _, m := range methods {
		set[m] = true
	}
	return set
}()
// validateCorsConfig returns the first problem found in cors, or nil when the
// configuration is acceptable (a nil config is accepted).
//
// Checks run in this order:
//   - each allow-origins entry is non-empty, is not "*" while
//     AllowCredentials is set, and passes validateOrigin;
//   - each allow-methods entry is a standard HTTP method (case-insensitive);
//   - allow-headers and expose-headers contain no blank entries;
//   - MaxAge is not negative.
func validateCorsConfig(cors *global.CorsConfig) error {
	if cors == nil {
		return nil
	}

	for _, origin := range cors.AllowOrigins {
		switch {
		case origin == "":
			return errors.New("allow-origins cannot contain empty string")
		case cors.AllowCredentials && origin == "*":
			// Browsers refuse credentialed responses with a wildcard origin.
			// validateOrigin accepts a bare "*", so this check may run first.
			return errors.New("allowCredentials cannot be true when allow-origins contains \"*\"")
		}
		if err := validateOrigin(origin); err != nil {
			return err
		}
	}

	for _, method := range cors.AllowMethods {
		// strings.ToUpper("") is "", which is not in validHTTPMethods, so
		// empty entries are rejected here too.
		if !validHTTPMethods[strings.ToUpper(method)] {
			return errors.New("allow-methods contains invalid HTTP method")
		}
	}

	err := validateHeaders(cors.AllowHeaders, "allow-headers")
	if err == nil {
		err = validateHeaders(cors.ExposeHeaders, "expose-headers")
	}
	if err != nil {
		return err
	}

	if cors.MaxAge < 0 {
		return errors.New("max-age cannot be negative")
	}
	return nil
}
// validateHeaders reports an error naming fieldName if any entry of headers
// is empty or consists only of whitespace. A nil or empty list is valid.
func validateHeaders(headers []string, fieldName string) error {
	for i := range headers {
		if len(strings.TrimSpace(headers[i])) == 0 {
			return errors.New(fieldName + " cannot contain empty string")
		}
	}
	return nil
}
// validateOrigin checks a single allow-origins entry.
//
// A bare "*" is accepted. Any whitespace is rejected. An entry containing "*"
// elsewhere is delegated to validateWildcardOrigin. An entry containing "://"
// must parse with net/url and have a non-empty scheme and host. Any other
// entry (e.g. a bare host name) is accepted as-is.
func validateOrigin(origin string) error {
	switch {
	case origin == "*":
		return nil
	case strings.ContainsAny(origin, " \t\n\r"):
		return errors.New("origin contains whitespace")
	case strings.Contains(origin, "*"):
		return validateWildcardOrigin(origin)
	case !strings.Contains(origin, "://"):
		return nil
	}

	parsed, err := url.Parse(origin)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return errors.New("invalid URL format")
	}
	return nil
}
// validateWildcardOrigin checks a subdomain-wildcard origin, which must have
// one of two forms:
//
//	*.domain
//	scheme://*.domain
//
// The "*" must be the leading label of the host and must be the only "*" in
// the origin. Both the domain part and, for the second form, the scheme part
// are checked for stray wildcards. The scheme and the domain must be
// non-empty.
func validateWildcardOrigin(origin string) error {
	const schemeWildcardSep = "://*."

	var domain string
	if idx := strings.Index(origin, schemeWildcardSep); idx >= 0 {
		scheme := origin[:idx]
		domain = origin[idx+len(schemeWildcardSep):]
		if scheme == "" || domain == "" {
			return errors.New("invalid subdomain wildcard format")
		}
		// Previously only the domain was inspected, so inputs like
		// "http*://*.example.com" or "*.foo://*.bar" slipped through.
		if strings.Contains(scheme, "*") {
			return errors.New("only single wildcard at subdomain level is allowed")
		}
	} else if strings.HasPrefix(origin, "*.") {
		domain = origin[len("*."):]
		if domain == "" {
			return errors.New("invalid subdomain wildcard format")
		}
	} else {
		return errors.New("wildcard must be at start: '*.domain' or 'scheme://*.domain'")
	}

	// Only single wildcard allowed: none may remain after the leading "*.".
	if strings.Contains(domain, "*") {
		return errors.New("only single wildcard at subdomain level is allowed")
	}
	return nil
}
// OpenAPIOption mutates one field of a global.OpenAPIConfig; pass
// OpenAPIOption values to WithOpenAPI.
type OpenAPIOption func(*global.OpenAPIConfig)

// WithOpenAPI builds an OpenAPI configuration from
// global.DefaultOpenAPIConfig() plus the given OpenAPIOption values, and
// returns an Option that installs it as the Triple OpenAPI configuration.
// The OpenAPIOption values run once, when WithOpenAPI is called.
func WithOpenAPI(opts ...OpenAPIOption) Option {
	cfg := global.DefaultOpenAPIConfig()
	for i := range opts {
		opts[i](cfg)
	}
	return func(o *Options) {
		o.Triple.OpenAPI = cfg
	}
}
// OpenAPIEnable turns on OpenAPI support by setting Enabled to true.
func OpenAPIEnable() OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		cfg.Enabled = true
	}
}
// OpenAPIInfoTitle sets the InfoTitle of the OpenAPI configuration.
func OpenAPIInfoTitle(title string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		cfg.InfoTitle = title
	}
}
// OpenAPIInfoVersion sets the InfoVersion of the OpenAPI configuration.
func OpenAPIInfoVersion(version string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		cfg.InfoVersion = version
	}
}
// OpenAPIInfoDescription sets the InfoDescription of the OpenAPI configuration.
func OpenAPIInfoDescription(description string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		cfg.InfoDescription = description
	}
}
// OpenAPIDefaultConsumesMediaTypes appends types to the existing
// DefaultConsumesMediaTypes (it does not replace values already present).
func OpenAPIDefaultConsumesMediaTypes(types ...string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		merged := append(cfg.DefaultConsumesMediaTypes, types...)
		cfg.DefaultConsumesMediaTypes = merged
	}
}
// OpenAPIDefaultProducesMediaTypes appends types to the existing
// DefaultProducesMediaTypes (it does not replace values already present).
func OpenAPIDefaultProducesMediaTypes(types ...string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		merged := append(cfg.DefaultProducesMediaTypes, types...)
		cfg.DefaultProducesMediaTypes = merged
	}
}
// OpenAPIDefaultHttpStatusCodes appends codes to the existing
// DefaultHttpStatusCodes (it does not replace values already present).
func OpenAPIDefaultHttpStatusCodes(codes ...string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		merged := append(cfg.DefaultHttpStatusCodes, codes...)
		cfg.DefaultHttpStatusCodes = merged
	}
}
// OpenAPIPath sets the Path field of the OpenAPI configuration.
func OpenAPIPath(path string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		cfg.Path = path
	}
}
// OpenAPISettings merges settings into the OpenAPI Settings map, creating the
// map if it is nil. Keys already present are overwritten.
func OpenAPISettings(settings map[string]string) OpenAPIOption {
	return func(cfg *global.OpenAPIConfig) {
		dst := cfg.Settings
		if dst == nil {
			dst = make(map[string]string, len(settings))
			cfg.Settings = dst
		}
		for key, value := range settings {
			dst[key] = value
		}
	}
}