New whatismydns feature (#29)

This commit is contained in:
2024-04-12 19:26:48 +02:00
committed by GitHub
parent b11f15ecfe
commit d13ea29071
20 changed files with 1571 additions and 210 deletions

85
router/dns.go Normal file
View File

@ -0,0 +1,85 @@
package router
import (
"fmt"
"net"
"net/http"
"strings"
validator "github.com/dcarrillo/whatismyip/internal/validator/uuid"
"github.com/dcarrillo/whatismyip/service"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/patrickmn/go-cache"
)
type DNSJSONResponse struct {
DNS dnsData `json:"dns"`
}
type dnsData struct {
IP string `json:"ip"`
Country string `json:"country"`
AsnOrganization string `json:"provider"`
}
// TODO
// Implement a proper vhost manager instead of using a middleware
func GetDNSDiscoveryHandler(store *cache.Cache, domain string, redirectPort string) gin.HandlerFunc {
return func(ctx *gin.Context) {
if !strings.HasSuffix(ctx.Request.Host, domain) {
ctx.Next()
return
}
if ctx.Request.Host == domain {
ctx.Redirect(http.StatusFound, fmt.Sprintf("http://%s.%s%s", uuid.New().String(), domain, redirectPort))
ctx.Abort()
return
}
handleDNS(ctx, store)
ctx.Abort()
}
}
func handleDNS(ctx *gin.Context, store *cache.Cache) {
d := strings.Split(ctx.Request.Host, ".")[0]
if !validator.IsValid(d) {
ctx.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
return
}
v, found := store.Get(d)
if !found {
ctx.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
return
}
ipStr, ok := v.(string)
if !ok {
ctx.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
return
}
ip := net.ParseIP(ipStr)
if ip == nil {
ctx.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
return
}
geo := service.Geo{IP: ip}
j := DNSJSONResponse{
DNS: dnsData{
IP: ipStr,
Country: geo.LookUpCity().Country.Names["en"],
AsnOrganization: geo.LookUpASN().AutonomousSystemOrganization,
},
}
switch ctx.NegotiateFormat(gin.MIMEPlain, gin.MIMEHTML, gin.MIMEJSON) {
case gin.MIMEJSON:
ctx.JSON(http.StatusOK, j)
default:
ctx.String(http.StatusOK, fmt.Sprintf("%s (%s / %s)\n", j.DNS.IP, j.DNS.Country, j.DNS.AsnOrganization))
}
}

140
router/dns_test.go Normal file
View File

@ -0,0 +1,140 @@
package router
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
validator "github.com/dcarrillo/whatismyip/internal/validator/uuid"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
)
func TestGetDNSDiscoveryHandler(t *testing.T) {
store := cache.New(cache.NoExpiration, cache.NoExpiration)
handler := GetDNSDiscoveryHandler(store, domain, "")
t.Run("calls next if host does not have domain suffix", func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Set(trustedHeader, testIP.ipv4)
req.Host = "example.com"
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
handler(c)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, testIP.ipv4+"\n", w.Body.String())
})
t.Run("redirects if host is domain", func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Host = domain
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
handler(c)
assert.Equal(t, http.StatusFound, w.Code)
r, err := url.Parse(w.Header().Get("Location"))
assert.NoError(t, err)
assert.True(t, validator.IsValid(strings.Split(r.Host, ".")[0]))
assert.Equal(t, domain, strings.Join(strings.Split(r.Host, ".")[1:], "."))
})
}
func TestHandleDNS(t *testing.T) {
store := cache.New(cache.NoExpiration, cache.NoExpiration)
u := uuid.New().String()
tests := []struct {
name string
subDomain string
stored any
}{
{
name: "not found if the subdomain is not a valid uuid",
subDomain: "not-uuid",
stored: "",
},
{
name: "not found if the ip is not found in the store",
subDomain: u,
stored: "",
},
{
name: "not found if the ip is in store but is not valid",
subDomain: u,
stored: "bogus",
},
{
name: "not found if the store contains no string",
subDomain: u,
stored: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Host = tt.subDomain + "." + domain
if tt.stored != "" {
store.Add(tt.subDomain, tt.stored, cache.DefaultExpiration)
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
handleDNS(c, store)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
}
func TestAcceptDNSRequest(t *testing.T) {
store := cache.New(cache.NoExpiration, cache.NoExpiration)
tests := []struct {
name string
accept string
want string
}{
{
name: "returns json dns data",
accept: "application/json",
want: jsonDNSIPv4,
},
{
name: "return plan text dns data",
accept: "text/plain",
want: plainDNSIPv4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
u := uuid.New().String()
req.Host = u + "." + domain
req.Header.Add("Accept", tt.accept)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
store.Add(u, testIP.ipv4, cache.DefaultExpiration)
handleDNS(c, store)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, tt.want, w.Body.String())
})
}
}

View File

@ -11,7 +11,6 @@ import (
"github.com/gin-gonic/gin"
)
// JSONResponse maps data as json
type JSONResponse struct {
IP string `json:"ip"`
IPVersion byte `json:"ip_version"`

View File

@ -8,7 +8,6 @@ import (
"github.com/gin-gonic/gin"
)
// SetupTemplate reads and parses a template from file
func SetupTemplate(r *gin.Engine) {
if setting.App.TemplatePath == "" {
t, _ := template.New("home").Parse(home)
@ -19,7 +18,6 @@ func SetupTemplate(r *gin.Engine) {
}
}
// Setup defines the endpoints
func Setup(r *gin.Engine) {
r.GET("/", getRoot)
r.GET("/scan/tcp/:port", scanTCPPort)

View File

@ -34,12 +34,17 @@ var (
text: "text/plain; charset=utf-8",
json: "application/json; charset=utf-8",
}
jsonIPv4 = `{"client_port":"1001","ip":"81.2.69.192","ip_version":4,"country":"United Kingdom","country_code":"GB","city":"London","latitude":51.5142,"longitude":-0.0931,"postal_code":"","time_zone":"Europe/London","asn":0,"asn_organization":"","host":"test", "headers": {}}`
jsonIPv6 = `{"asn":3352, "asn_organization":"TELEFONICA DE ESPANA", "city":"", "client_port":"1001", "country":"", "country_code":"", "host":"test", "ip":"2a02:9000::1", "ip_version":6, "latitude":0, "longitude":0, "postal_code":"", "time_zone":"", "headers": {}}`
jsonIPv4 = `{"client_port":"1001","ip":"81.2.69.192","ip_version":4,"country":"United Kingdom","country_code":"GB","city":"London","latitude":51.5142,"longitude":-0.0931,"postal_code":"","time_zone":"Europe/London","asn":0,"asn_organization":"","host":"test", "headers": {}}`
jsonIPv6 = `{"asn":3352, "asn_organization":"TELEFONICA DE ESPANA", "city":"", "client_port":"1001", "country":"", "country_code":"", "host":"test", "ip":"2a02:9000::1", "ip_version":6, "latitude":0, "longitude":0, "postal_code":"", "time_zone":"", "headers": {}}`
jsonDNSIPv4 = `{"dns":{"ip":"81.2.69.192","country":"United Kingdom","provider":""}}`
plainDNSIPv4 = "81.2.69.192 (United Kingdom / )\n"
)
const trustedHeader = "X-Real-IP"
const trustedPortHeader = "X-Real-Port"
const (
trustedHeader = "X-Real-IP"
trustedPortHeader = "X-Real-Port"
domain = "dns.example.com"
)
func TestMain(m *testing.M) {
app = gin.Default()