mirror of
https://github.com/dcarrillo/whatismyip.git
synced 2025-07-01 17:29:27 +00:00
Add feature to get the right client port when using a trusted proxy
This commit is contained in:
@ -27,6 +27,7 @@ type settings struct {
|
||||
TLSCrtPath string
|
||||
TLSKeyPath string
|
||||
TrustedHeader string
|
||||
TrustedPortHeader string
|
||||
EnableSecureHeaders bool
|
||||
Server serverSettings
|
||||
version bool
|
||||
@ -74,6 +75,11 @@ func Setup(args []string) (output string, err error) {
|
||||
"",
|
||||
"Trusted request header for remote IP (e.g. X-Real-IP)",
|
||||
)
|
||||
flags.StringVar(&App.TrustedPortHeader,
|
||||
"trusted-port-header",
|
||||
"",
|
||||
"Trusted request header for remote client port (e.g. X-Real-Port)",
|
||||
)
|
||||
flags.BoolVar(&App.version, "version", false, "Output version information and exit")
|
||||
flags.BoolVar(
|
||||
&App.EnableSecureHeaders,
|
||||
@ -91,21 +97,25 @@ func Setup(args []string) (output string, err error) {
|
||||
return fmt.Sprintf("whatismyip version %s", core.Version), ErrVersion
|
||||
}
|
||||
|
||||
if App.TrustedPortHeader != "" && App.TrustedHeader == "" {
|
||||
return "", fmt.Errorf("truster-header is mandatory when truster-port-header is set\n")
|
||||
}
|
||||
|
||||
if App.GeodbPath.City == "" || App.GeodbPath.ASN == "" {
|
||||
return "", fmt.Errorf("geoip2-city and geoip2-asn parameters are mandatory")
|
||||
return "", fmt.Errorf("geoip2-city and geoip2-asn parameters are mandatory\n")
|
||||
}
|
||||
|
||||
if (App.TLSAddress != "") && (App.TLSCrtPath == "" || App.TLSKeyPath == "") {
|
||||
return "", fmt.Errorf("In order to use TLS -tls-crt and -tls-key flags are mandatory")
|
||||
return "", fmt.Errorf("In order to use TLS -tls-crt and -tls-key flags are mandatory\n")
|
||||
}
|
||||
|
||||
if App.TemplatePath != "" {
|
||||
info, err := os.Stat(App.TemplatePath)
|
||||
if os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("%s no such file or directory", App.TemplatePath)
|
||||
return "", fmt.Errorf("%s no such file or directory\n", App.TemplatePath)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return "", fmt.Errorf("%s must be a file", App.TemplatePath)
|
||||
return "", fmt.Errorf("%s must be a file\n", App.TemplatePath)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,51 +8,51 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseMandatoryFlags(t *testing.T) {
|
||||
var mandatoryFlags = []struct {
|
||||
args []string
|
||||
conf settings
|
||||
}{
|
||||
{
|
||||
[]string{},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{"-geoip2-city", "/city-path"},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{"-geoip2-asn", "/asn-path"},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{
|
||||
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
|
||||
},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{
|
||||
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
|
||||
"-tls-crt", "/crt-path",
|
||||
},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{
|
||||
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
|
||||
"-tls-key", "/key-path",
|
||||
},
|
||||
settings{},
|
||||
},
|
||||
{
|
||||
[]string{
|
||||
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-bind", ":8000",
|
||||
"-trusted-port-header", "port-header",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range mandatoryFlags {
|
||||
t.Run(strings.Join(tt.args, " "), func(t *testing.T) {
|
||||
_, err := Setup(tt.args)
|
||||
assert.NotNil(t, err)
|
||||
require.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), "mandatory")
|
||||
})
|
||||
}
|
||||
@ -70,13 +70,7 @@ func TestParseFlags(t *testing.T) {
|
||||
City: "/city-path",
|
||||
ASN: "/asn-path",
|
||||
},
|
||||
TemplatePath: "",
|
||||
BindAddress: ":8080",
|
||||
TLSAddress: "",
|
||||
TLSCrtPath: "",
|
||||
TLSKeyPath: "",
|
||||
TrustedHeader: "",
|
||||
EnableSecureHeaders: false,
|
||||
BindAddress: ":8080",
|
||||
Server: serverSettings{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
@ -90,13 +84,7 @@ func TestParseFlags(t *testing.T) {
|
||||
City: "/city-path",
|
||||
ASN: "/asn-path",
|
||||
},
|
||||
TemplatePath: "",
|
||||
BindAddress: ":8001",
|
||||
TLSAddress: "",
|
||||
TLSCrtPath: "",
|
||||
TLSKeyPath: "",
|
||||
TrustedHeader: "",
|
||||
EnableSecureHeaders: false,
|
||||
BindAddress: ":8001",
|
||||
Server: serverSettings{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
@ -113,13 +101,29 @@ func TestParseFlags(t *testing.T) {
|
||||
City: "/city-path",
|
||||
ASN: "/asn-path",
|
||||
},
|
||||
TemplatePath: "",
|
||||
BindAddress: ":8080",
|
||||
TLSAddress: ":9000",
|
||||
TLSCrtPath: "/crt-path",
|
||||
TLSKeyPath: "/key-path",
|
||||
TrustedHeader: "",
|
||||
EnableSecureHeaders: false,
|
||||
BindAddress: ":8080",
|
||||
TLSAddress: ":9000",
|
||||
TLSCrtPath: "/crt-path",
|
||||
TLSKeyPath: "/key-path",
|
||||
Server: serverSettings{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
[]string{
|
||||
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path",
|
||||
"-trusted-header", "header", "-trusted-port-header", "port-header",
|
||||
},
|
||||
settings{
|
||||
GeodbPath: geodbPath{
|
||||
City: "/city-path",
|
||||
ASN: "/asn-path",
|
||||
},
|
||||
BindAddress: ":8080",
|
||||
TrustedHeader: "header",
|
||||
TrustedPortHeader: "port-header",
|
||||
Server: serverSettings{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
@ -136,11 +140,7 @@ func TestParseFlags(t *testing.T) {
|
||||
City: "/city-path",
|
||||
ASN: "/asn-path",
|
||||
},
|
||||
TemplatePath: "",
|
||||
BindAddress: ":8080",
|
||||
TLSAddress: "",
|
||||
TLSCrtPath: "",
|
||||
TLSKeyPath: "",
|
||||
TrustedHeader: "header",
|
||||
EnableSecureHeaders: true,
|
||||
Server: serverSettings{
|
||||
@ -154,7 +154,7 @@ func TestParseFlags(t *testing.T) {
|
||||
for _, tt := range flags {
|
||||
t.Run(strings.Join(tt.args, " "), func(t *testing.T) {
|
||||
_, err := Setup(tt.args)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, reflect.DeepEqual(App, tt.conf))
|
||||
})
|
||||
}
|
||||
@ -192,6 +192,6 @@ func TestParseFlagTemplate(t *testing.T) {
|
||||
"-template", "/",
|
||||
}
|
||||
_, err = Setup(flags)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "must be a file")
|
||||
}
|
||||
|
Reference in New Issue
Block a user