blob: 3c8f445cc38bccb4551e6943f29a6f9e80da287e [file] [log] [blame]
// Copyright Istio Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package forwarder
import (
"bytes"
"context"
"fmt"
"net"
"net/url"
"strings"
)
var _ protocol = &dnsProtocol{}
type dnsProtocol struct{}
type dnsRequest struct {
hostname string
dnsServer string
query string
protocol string
}
func checkIn(got string, want ...string) error {
for _, w := range want {
if w == got {
return nil
}
}
return fmt.Errorf("got value %q, wanted one of %v", got, want)
}
func parseRequest(inputURL string) (dnsRequest, error) {
req := dnsRequest{}
u, err := url.Parse(inputURL)
if err != nil {
return req, err
}
qp, err := url.ParseQuery(u.RawQuery)
if err != nil {
return req, err
}
req.protocol = qp.Get("protocol")
if err := checkIn(req.protocol, "", "udp", "tcp"); err != nil {
return req, err
}
req.dnsServer = qp.Get("server")
if req.dnsServer != "" {
if _, _, err := net.SplitHostPort(req.dnsServer); err != nil && strings.Contains(err.Error(), "missing port in address") {
req.dnsServer += ":53"
}
}
req.hostname = u.Host
req.query = qp.Get("query")
if err := checkIn(req.query, "", "A", "AAAA"); err != nil {
return req, err
}
return req, nil
}
func (c *dnsProtocol) makeRequest(ctx context.Context, rreq *request) (string, error) {
req, err := parseRequest(rreq.URL)
if err != nil {
return "", err
}
r := newResolver(rreq.Timeout, req.protocol, req.dnsServer)
nt := func() string {
switch req.query {
case "A":
return "ip4"
case "AAAA":
return "ip6"
default:
return "ip"
}
}()
ctx, cancel := context.WithTimeout(ctx, rreq.Timeout)
defer cancel()
ips, err := r.LookupIP(ctx, nt, req.hostname)
if err != nil {
return "", err
}
var outBuffer bytes.Buffer
outBuffer.WriteString(fmt.Sprintf("[%d] Hostname=%s\n", rreq.RequestID, req.hostname))
outBuffer.WriteString(fmt.Sprintf("[%d] Protocol=%s\n", rreq.RequestID, req.protocol))
outBuffer.WriteString(fmt.Sprintf("[%d] Query=%s\n", rreq.RequestID, req.query))
outBuffer.WriteString(fmt.Sprintf("[%d] DnsServer=%s\n", rreq.RequestID, req.dnsServer))
for n, i := range ips {
outBuffer.WriteString(fmt.Sprintf("[%d body] Response%d=%s\n", rreq.RequestID, n, i.String()))
}
return outBuffer.String(), nil
}
func (c *dnsProtocol) Close() error {
return nil
}