| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package triple |
| |
| import ( |
| "errors" |
| "net/http" |
| "net/url" |
| "strings" |
| "time" |
| ) |
| |
| import ( |
| "github.com/dubbogo/gost/log/logger" |
| ) |
| |
| import ( |
| "dubbo.apache.org/dubbo-go/v3/global" |
| ) |
| |
// Options holds the Triple protocol configuration that the Option functions
// in this package populate.
//
// TODO: The triple options for the server and client are mixed together now.
// We need to find a way to separate them later.
type Options struct {
	// Triple is the protocol configuration every Option writes into.
	Triple *global.TripleConfig
}
| |
| func defaultOptions() *Options { |
| return &Options{Triple: global.DefaultTripleConfig()} |
| } |
| |
| func NewOptions(opts ...Option) *Options { |
| defSrvOpts := defaultOptions() |
| for _, opt := range opts { |
| opt(defSrvOpts) |
| } |
| return defSrvOpts |
| } |
| |
// Option mutates an Options in place. NewOptions applies a list of them, in
// order, on top of the defaults.
type Option func(*Options)
| |
| // WithKeepAlive sets the keep-alive interval and timeout for the Triple protocol. |
| // interval: The duration between keep-alive pings. |
| // timeout: The duration to wait for a keep-alive response before considering the connection dead. |
| // If not set, default interval is 10s, default timeout is 20s. |
| func WithKeepAlive(interval, timeout time.Duration) Option { |
| return func(opts *Options) { |
| opts.Triple.KeepAliveInterval = interval.String() |
| opts.Triple.KeepAliveTimeout = timeout.String() |
| } |
| } |
| |
| // WithKeepAliveInterval sets the keep-alive interval for the Triple protocol. |
| // interval: The duration between keep-alive pings. |
| // If not set, default interval is 10s. |
| func WithKeepAliveInterval(interval time.Duration) Option { |
| return func(opts *Options) { |
| opts.Triple.KeepAliveInterval = interval.String() |
| } |
| } |
| |
| // WithKeepAliveTimeout sets the keep-alive timeout for the Triple protocol. |
| // timeout: The duration to wait for a keep-alive response before considering the connection dead. |
| // If not set, default timeout is 20s. |
| func WithKeepAliveTimeout(timeout time.Duration) Option { |
| return func(opts *Options) { |
| opts.Triple.KeepAliveTimeout = timeout.String() |
| } |
| } |
| |
| // WithMaxServerSendMsgSize sets the maximum size of messages that the server can send. |
| // size: The maximum message size in bytes, specified as a string (e.g., "4MB"). |
| // If not set, default value is 2147MB (math.MaxInt32). |
| func WithMaxServerSendMsgSize(size string) Option { |
| return func(opts *Options) { |
| opts.Triple.MaxServerSendMsgSize = size |
| } |
| } |
| |
| // WithMaxServerRecvMsgSize sets the maximum size of messages that the server can receive. |
| // size: The maximum message size in bytes, specified as a string (e.g., "4MB"). |
| // If not set, default value is 4MB (4194304 bytes). |
| func WithMaxServerRecvMsgSize(size string) Option { |
| return func(opts *Options) { |
| opts.Triple.MaxServerRecvMsgSize = size |
| } |
| } |
| |
| // WithCORS applies CORS configuration to triple options. |
| // Invalid configs are logged as errors and ignored (no-op). |
| func WithCORS(opts ...CORSOption) Option { |
| cors := global.DefaultCorsConfig() |
| for _, opt := range opts { |
| opt(cors) |
| } |
| if err := validateCorsConfig(cors); err != nil { |
| logger.Errorf("[TRIPLE] invalid CORS config: %v", err) |
| // Return a no-op function to ignore invalid CORS configuration |
| return func(*Options) {} |
| } |
| return func(opts *Options) { |
| opts.Triple.Cors = cors |
| } |
| } |
| |
| // Http3Enable enables HTTP/3 support for the Triple protocol. |
| // This option configures the server to start both HTTP/2 and HTTP/3 servers |
| // simultaneously, providing modern HTTP/3 capabilities alongside traditional HTTP/2. |
| // |
| // When enabled, the server will: |
| // - Start an HTTP/3 server using QUIC protocol |
| // - Continue running the existing HTTP/2 server |
| // - Enable protocol negotiation between HTTP/2 and HTTP/3 |
| // - Provide improved performance and security benefits of HTTP/3 |
| // |
| // Usage Examples: |
| // |
| // // Basic HTTP/3 enablement |
| // server := triple.NewServer( |
| // triple.Http3Enable(), |
| // ) |
| // |
| // Requirements: |
| // - TLS configuration is required for HTTP/3 |
| // - Server must have valid TLS certificates |
| // - Clients must support HTTP/3 for full benefits |
| // - Fallback to HTTP/2 is automatic for unsupported clients |
| // |
| // Default Behavior: |
| // - HTTP/3 is disabled by default for backward compatibility |
| // - When enabled, negotiation defaults to true |
| // - Both HTTP/2 and HTTP/3 servers run on the same port |
| // |
| // # Experimental |
| // |
| // NOTICE: This API is EXPERIMENTAL and may be changed or removed in |
| // a later release. |
| func Http3Enable() Option { |
| return func(opts *Options) { |
| opts.Triple.Http3.Enable = true |
| } |
| } |
| |
| // Http3Negotiation configures HTTP/3 negotiation behavior for the Triple protocol. |
| // This option controls whether HTTP/2 Alternative Services (Alt-Svc) negotiation |
| // is enabled when both HTTP/2 and HTTP/3 servers are running simultaneously. |
| // |
| // Usage Examples: |
| // |
| // // Enable HTTP/3 negotiation (default behavior) |
| // server := triple.NewServer( |
| // triple.Http3Enable(), |
| // triple.Http3Negotiation(true), |
| // ) |
| // |
| // // Disable HTTP/3 negotiation for explicit protocol control |
| // server := triple.NewServer( |
| // triple.Http3Enable(), |
| // triple.Http3Negotiation(false), |
| // ) |
| // |
| // Default Behavior: |
| // - When HTTP/3 is enabled, negotiation defaults to true |
| // - This ensures backward compatibility and optimal client experience |
| // |
| // # Experimental |
| // |
| // NOTICE: This API is EXPERIMENTAL and may be changed or removed in |
| // a later release. |
| func Http3Negotiation(negotiation bool) Option { |
| return func(opts *Options) { |
| opts.Triple.Http3.Negotiation = negotiation |
| } |
| } |
| |
// CORSOption configures a single aspect of CORS. WithCORS applies these, in
// order, to a config obtained from global.DefaultCorsConfig() before
// validating the result.
type CORSOption func(*global.CorsConfig)
| |
| // CORSAllowOrigins sets allowed origins for CORS requests. |
| func CORSAllowOrigins(origins ...string) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.AllowOrigins = append([]string(nil), origins...) |
| } |
| } |
| |
| // CORSAllowMethods sets allowed HTTP methods for CORS requests. |
| func CORSAllowMethods(methods ...string) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.AllowMethods = append([]string(nil), methods...) |
| } |
| } |
| |
| // CORSAllowHeaders sets allowed request headers for CORS requests. |
| func CORSAllowHeaders(headers ...string) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.AllowHeaders = append([]string(nil), headers...) |
| } |
| } |
| |
| // CORSExposeHeaders sets headers exposed to the browser. |
| func CORSExposeHeaders(headers ...string) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.ExposeHeaders = append([]string(nil), headers...) |
| } |
| } |
| |
| // CORSAllowCredentials toggles whether credentials are allowed. |
| func CORSAllowCredentials(allow bool) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.AllowCredentials = allow |
| } |
| } |
| |
| // CORSMaxAge sets the max age for preflight cache. |
| func CORSMaxAge(maxAge int) CORSOption { |
| return func(c *global.CorsConfig) { |
| c.MaxAge = maxAge |
| } |
| } |
| |
// validHTTPMethods is the set of method names accepted in a CORS allow-methods
// list. Keys are the canonical upper-case constants from net/http;
// validateCorsConfig upper-cases user input before the lookup.
var validHTTPMethods = map[string]bool{
	http.MethodConnect: true,
	http.MethodDelete:  true,
	http.MethodGet:     true,
	http.MethodHead:    true,
	http.MethodOptions: true,
	http.MethodPatch:   true,
	http.MethodPost:    true,
	http.MethodPut:     true,
	http.MethodTrace:   true,
}
| |
| // validateCorsConfig validates CORS configuration. |
| func validateCorsConfig(cors *global.CorsConfig) error { |
| if cors == nil { |
| return nil |
| } |
| |
| // Validate origins |
| for _, origin := range cors.AllowOrigins { |
| if origin == "" { |
| return errors.New("allow-origins cannot contain empty string") |
| } |
| if err := validateOrigin(origin); err != nil { |
| return err |
| } |
| if cors.AllowCredentials && origin == "*" { |
| return errors.New("allowCredentials cannot be true when allow-origins contains \"*\"") |
| } |
| } |
| |
| // Validate methods |
| for _, method := range cors.AllowMethods { |
| if method == "" || !validHTTPMethods[strings.ToUpper(method)] { |
| return errors.New("allow-methods contains invalid HTTP method") |
| } |
| } |
| |
| // Validate headers (both allow and expose) |
| if err := validateHeaders(cors.AllowHeaders, "allow-headers"); err != nil { |
| return err |
| } |
| if err := validateHeaders(cors.ExposeHeaders, "expose-headers"); err != nil { |
| return err |
| } |
| |
| if cors.MaxAge < 0 { |
| return errors.New("max-age cannot be negative") |
| } |
| |
| return nil |
| } |
| |
// validateHeaders rejects any header entry that is empty or consists only of
// whitespace. fieldName is used as the prefix of the returned error message.
func validateHeaders(headers []string, fieldName string) error {
	for i := range headers {
		if len(strings.TrimSpace(headers[i])) == 0 {
			return errors.New(fieldName + " cannot contain empty string")
		}
	}
	return nil
}
| |
| func validateOrigin(origin string) error { |
| // Allow wildcard |
| if origin == "*" { |
| return nil |
| } |
| |
| // Check for whitespace |
| if strings.ContainsAny(origin, " \t\n\r") { |
| return errors.New("origin contains whitespace") |
| } |
| |
| // Handle subdomain wildcard (*.example.com or https://*.example.com) |
| if strings.Contains(origin, "*") { |
| return validateWildcardOrigin(origin) |
| } |
| |
| // Validate URL format |
| if strings.Contains(origin, "://") { |
| u, err := url.Parse(origin) |
| if err != nil || u.Scheme == "" || u.Host == "" { |
| return errors.New("invalid URL format") |
| } |
| } |
| |
| return nil |
| } |
| |
// validateWildcardOrigin validates an origin that contains '*'. Only two
// subdomain-wildcard shapes are accepted:
//
//	*.domain           e.g. "*.example.com"
//	scheme://*.domain  e.g. "https://*.example.com"
//
// The '*' must be the entire leftmost host label, and the domain after "*."
// must be non-empty. When "://" is present, the scheme before it must also be
// non-empty. No other '*' may appear anywhere, in the domain or in the scheme.
func validateWildcardOrigin(origin string) error {
	const schemeWildcard = "://*."

	var scheme, domain string
	switch idx := strings.Index(origin, schemeWildcard); {
	case idx >= 0:
		scheme, domain = origin[:idx], origin[idx+len(schemeWildcard):]
		// "://*.example.com" has no scheme at all; the non-wildcard path in
		// validateOrigin rejects the equivalent input via url.Parse, so be
		// consistent here.
		if scheme == "" {
			return errors.New("invalid subdomain wildcard format: missing scheme")
		}
	case strings.HasPrefix(origin, "*."):
		domain = origin[len("*."):]
	default:
		return errors.New("wildcard must be at start: '*.domain' or 'scheme://*.domain'")
	}

	if domain == "" {
		return errors.New("invalid subdomain wildcard format")
	}

	// Only a single wildcard is allowed. Check the scheme as well as the
	// domain, otherwise inputs like "*://*.example.com" slip through.
	if strings.Contains(scheme, "*") || strings.Contains(domain, "*") {
		return errors.New("only single wildcard at subdomain level is allowed")
	}

	return nil
}
| |
// OpenAPIOption configures a single aspect of the OpenAPI settings. WithOpenAPI
// applies these, in order, to a config obtained from
// global.DefaultOpenAPIConfig().
type OpenAPIOption func(*global.OpenAPIConfig)
| |
| func WithOpenAPI(opts ...OpenAPIOption) Option { |
| openapi := global.DefaultOpenAPIConfig() |
| for _, opt := range opts { |
| opt(openapi) |
| } |
| return func(opts *Options) { |
| opts.Triple.OpenAPI = openapi |
| } |
| } |
| |
| func OpenAPIEnable() OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.Enabled = true |
| } |
| } |
| |
| func OpenAPIInfoTitle(title string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.InfoTitle = title |
| } |
| } |
| |
| func OpenAPIInfoVersion(version string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.InfoVersion = version |
| } |
| } |
| |
| func OpenAPIInfoDescription(description string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.InfoDescription = description |
| } |
| } |
| |
| func OpenAPIDefaultConsumesMediaTypes(types ...string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.DefaultConsumesMediaTypes = append(o.DefaultConsumesMediaTypes, types...) |
| } |
| } |
| |
| func OpenAPIDefaultProducesMediaTypes(types ...string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.DefaultProducesMediaTypes = append(o.DefaultProducesMediaTypes, types...) |
| } |
| } |
| |
| func OpenAPIDefaultHttpStatusCodes(codes ...string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.DefaultHttpStatusCodes = append(o.DefaultHttpStatusCodes, codes...) |
| } |
| } |
| |
| func OpenAPIPath(path string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| o.Path = path |
| } |
| } |
| |
| func OpenAPISettings(settings map[string]string) OpenAPIOption { |
| return func(o *global.OpenAPIConfig) { |
| if o.Settings == nil { |
| o.Settings = make(map[string]string) |
| } |
| for k, v := range settings { |
| o.Settings[k] = v |
| } |
| } |
| } |