Add feature to get the right client port when using a trusted proxy

This commit is contained in:
2022-05-01 19:47:27 +02:00
parent 9070e9a2c2
commit 7c70abf07f
7 changed files with 178 additions and 64 deletions

View File

@ -27,6 +27,7 @@ type settings struct {
TLSCrtPath string
TLSKeyPath string
TrustedHeader string
TrustedPortHeader string
EnableSecureHeaders bool
Server serverSettings
version bool
@ -74,6 +75,11 @@ func Setup(args []string) (output string, err error) {
"",
"Trusted request header for remote IP (e.g. X-Real-IP)",
)
flags.StringVar(&App.TrustedPortHeader,
"trusted-port-header",
"",
"Trusted request header for remote client port (e.g. X-Real-Port)",
)
flags.BoolVar(&App.version, "version", false, "Output version information and exit")
flags.BoolVar(
&App.EnableSecureHeaders,
@ -91,21 +97,25 @@ func Setup(args []string) (output string, err error) {
return fmt.Sprintf("whatismyip version %s", core.Version), ErrVersion
}
if App.TrustedPortHeader != "" && App.TrustedHeader == "" {
return "", fmt.Errorf("truster-header is mandatory when truster-port-header is set\n")
}
if App.GeodbPath.City == "" || App.GeodbPath.ASN == "" {
return "", fmt.Errorf("geoip2-city and geoip2-asn parameters are mandatory")
return "", fmt.Errorf("geoip2-city and geoip2-asn parameters are mandatory\n")
}
if (App.TLSAddress != "") && (App.TLSCrtPath == "" || App.TLSKeyPath == "") {
return "", fmt.Errorf("In order to use TLS -tls-crt and -tls-key flags are mandatory")
return "", fmt.Errorf("In order to use TLS -tls-crt and -tls-key flags are mandatory\n")
}
if App.TemplatePath != "" {
info, err := os.Stat(App.TemplatePath)
if os.IsNotExist(err) {
return "", fmt.Errorf("%s no such file or directory", App.TemplatePath)
return "", fmt.Errorf("%s no such file or directory\n", App.TemplatePath)
}
if info.IsDir() {
return "", fmt.Errorf("%s must be a file", App.TemplatePath)
return "", fmt.Errorf("%s must be a file\n", App.TemplatePath)
}
}

View File

@ -8,51 +8,51 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseMandatoryFlags(t *testing.T) {
var mandatoryFlags = []struct {
args []string
conf settings
}{
{
[]string{},
settings{},
},
{
[]string{"-geoip2-city", "/city-path"},
settings{},
},
{
[]string{"-geoip2-asn", "/asn-path"},
settings{},
},
{
[]string{
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
},
settings{},
},
{
[]string{
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
"-tls-crt", "/crt-path",
},
settings{},
},
{
[]string{
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000",
"-tls-key", "/key-path",
},
settings{},
},
{
[]string{
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-bind", ":8000",
"-trusted-port-header", "port-header",
},
},
}
for _, tt := range mandatoryFlags {
t.Run(strings.Join(tt.args, " "), func(t *testing.T) {
_, err := Setup(tt.args)
assert.NotNil(t, err)
require.NotNil(t, err)
assert.Contains(t, err.Error(), "mandatory")
})
}
@ -70,13 +70,7 @@ func TestParseFlags(t *testing.T) {
City: "/city-path",
ASN: "/asn-path",
},
TemplatePath: "",
BindAddress: ":8080",
TLSAddress: "",
TLSCrtPath: "",
TLSKeyPath: "",
TrustedHeader: "",
EnableSecureHeaders: false,
BindAddress: ":8080",
Server: serverSettings{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
@ -90,13 +84,7 @@ func TestParseFlags(t *testing.T) {
City: "/city-path",
ASN: "/asn-path",
},
TemplatePath: "",
BindAddress: ":8001",
TLSAddress: "",
TLSCrtPath: "",
TLSKeyPath: "",
TrustedHeader: "",
EnableSecureHeaders: false,
BindAddress: ":8001",
Server: serverSettings{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
@ -113,13 +101,29 @@ func TestParseFlags(t *testing.T) {
City: "/city-path",
ASN: "/asn-path",
},
TemplatePath: "",
BindAddress: ":8080",
TLSAddress: ":9000",
TLSCrtPath: "/crt-path",
TLSKeyPath: "/key-path",
TrustedHeader: "",
EnableSecureHeaders: false,
BindAddress: ":8080",
TLSAddress: ":9000",
TLSCrtPath: "/crt-path",
TLSKeyPath: "/key-path",
Server: serverSettings{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
},
},
},
{
[]string{
"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path",
"-trusted-header", "header", "-trusted-port-header", "port-header",
},
settings{
GeodbPath: geodbPath{
City: "/city-path",
ASN: "/asn-path",
},
BindAddress: ":8080",
TrustedHeader: "header",
TrustedPortHeader: "port-header",
Server: serverSettings{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
@ -136,11 +140,7 @@ func TestParseFlags(t *testing.T) {
City: "/city-path",
ASN: "/asn-path",
},
TemplatePath: "",
BindAddress: ":8080",
TLSAddress: "",
TLSCrtPath: "",
TLSKeyPath: "",
TrustedHeader: "header",
EnableSecureHeaders: true,
Server: serverSettings{
@ -154,7 +154,7 @@ func TestParseFlags(t *testing.T) {
for _, tt := range flags {
t.Run(strings.Join(tt.args, " "), func(t *testing.T) {
_, err := Setup(tt.args)
assert.Nil(t, err)
require.Nil(t, err)
assert.True(t, reflect.DeepEqual(App, tt.conf))
})
}
@ -192,6 +192,6 @@ func TestParseFlagTemplate(t *testing.T) {
"-template", "/",
}
_, err = Setup(flags)
assert.Error(t, err)
require.Error(t, err)
assert.Contains(t, err.Error(), "must be a file")
}