diff --git a/cmd/whatismyip.go b/cmd/whatismyip.go index 2b55d65..9168747 100644 --- a/cmd/whatismyip.go +++ b/cmd/whatismyip.go @@ -3,6 +3,8 @@ package main import ( "context" "errors" + "flag" + "fmt" "log" "net/http" "os" @@ -24,7 +26,15 @@ var ( ) func main() { - setting.Setup() + o, err := setting.Setup(os.Args[1:]) + if err == flag.ErrHelp || err == setting.ErrVersion { + fmt.Print(o) + os.Exit(0) + } else if err != nil { + fmt.Print(err) + os.Exit(1) + } + models.Setup(setting.App.GeodbPath.City, setting.App.GeodbPath.ASN) setupEngine() router.SetupTemplate(engine) diff --git a/internal/setting/app.go b/internal/setting/app.go index 5b6c680..1d454c1 100644 --- a/internal/setting/app.go +++ b/internal/setting/app.go @@ -1,6 +1,8 @@ package setting import ( + "bytes" + "errors" "flag" "fmt" "os" @@ -26,74 +28,75 @@ type settings struct { TLSKeyPath string TrustedHeader string Server serverSettings + version bool } const defaultAddress = ":8080" -var App *settings +var ErrVersion = errors.New("setting: version requested") +var App = settings{ + // hard-coded for the time being + Server: serverSettings{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, +} -func Setup() { - city := flag.String("geoip2-city", "", "Path to GeoIP2 city database") - asn := flag.String("geoip2-asn", "", "Path to GeoIP2 ASN database") - template := flag.String("template", "", "Path to template file") - address := flag.String( +func Setup(args []string) (output string, err error) { + flags := flag.NewFlagSet("whatismyip", flag.ContinueOnError) + var buf bytes.Buffer + flags.SetOutput(&buf) + + flags.StringVar(&App.GeodbPath.City, "geoip2-city", "", "Path to GeoIP2 city database") + flags.StringVar(&App.GeodbPath.ASN, "geoip2-asn", "", "Path to GeoIP2 ASN database") + flags.StringVar(&App.TemplatePath, "template", "", "Path to template file") + flags.StringVar( + &App.BindAddress, "bind", defaultAddress, "Listening address (see https://pkg.go.dev/net?#Listen)", ) - addressTLS := flag.String( + flags.StringVar( + &App.TLSAddress, "tls-bind", "", "Listening address for TLS (see https://pkg.go.dev/net?#Listen)", ) - tlsCrtPath := flag.String("tls-crt", "", "When using TLS, path to certificate file") - tlsKeyPath := flag.String("tls-key", "", "When using TLS, path to private key file") - trustedHeader := flag.String( + flags.StringVar(&App.TLSCrtPath, "tls-crt", "", "When using TLS, path to certificate file") + flags.StringVar(&App.TLSKeyPath, "tls-key", "", "When using TLS, path to private key file") + flags.StringVar(&App.TrustedHeader, "trusted-header", "", "Trusted request header for remote IP (e.g. X-Real-IP)", ) - ver := flag.Bool("version", false, "Output version information and exit") + flags.BoolVar(&App.version, "version", false, "Output version information and exit") - flag.Parse() - - if *ver { - fmt.Printf("whatismyip version %s", core.Version) - os.Exit(0) + err = flags.Parse(args) + if err != nil { + return buf.String(), err } - if *city == "" || *asn == "" { - exitWithError("geoip2-city and geoip2-asn parameters are mandatory") + if App.version { + return fmt.Sprintf("whatismyip version %s", core.Version), ErrVersion } - if (*addressTLS != "") && (*tlsCrtPath == "" || *tlsKeyPath == "") { - exitWithError("In order to use TLS -tls-crt and -tls-key flags are mandatory") + if App.GeodbPath.City == "" || App.GeodbPath.ASN == "" { + return "", fmt.Errorf("geoip2-city and geoip2-asn parameters are mandatory") } - if *template != "" { - info, err := os.Stat(*template) - if os.IsNotExist(err) || info.IsDir() { - exitWithError(*template + " doesn't exist or it's not a file") + if (App.TLSAddress != "") && (App.TLSCrtPath == "" || App.TLSKeyPath == "") { + return "", fmt.Errorf("In order to use TLS -tls-crt and -tls-key flags are mandatory") + } + + if App.TemplatePath != "" { + info, err := os.Stat(App.TemplatePath) + if os.IsNotExist(err) { + return "", fmt.Errorf("%s no such file or directory", App.TemplatePath) + } + if info.IsDir() { + return "", fmt.Errorf("%s must be a file", App.TemplatePath) } } - App = &settings{ - GeodbPath: geodbPath{City: *city, ASN: *asn}, - TemplatePath: *template, - BindAddress: *address, - TLSAddress: *addressTLS, - TLSCrtPath: *tlsCrtPath, - TLSKeyPath: *tlsKeyPath, - TrustedHeader: *trustedHeader, - Server: serverSettings{ - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - }, - } -} - -func exitWithError(error string) { - fmt.Printf("%s\n\n", error) - flag.Usage() - os.Exit(1) + return buf.String(), nil } diff --git a/internal/setting/app_test.go b/internal/setting/app_test.go new file mode 100644 index 0000000..7530840 --- /dev/null +++ b/internal/setting/app_test.go @@ -0,0 +1,193 @@ +package setting + +import ( + "flag" + "reflect" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseMandatoryFlags(t *testing.T) { + var mandatoryFlags = []struct { + args []string + conf settings + }{ + { + []string{}, + settings{}, + }, + { + []string{"-geoip2-city", "/city-path"}, + settings{}, + }, + { + []string{"-geoip2-asn", "/asn-path"}, + settings{}, + }, + { + []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000", + }, + settings{}, + }, + { + []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000", + "-tls-crt", "/crt-path", + }, + settings{}, + }, + { + []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000", + "-tls-key", "/key-path", + }, + settings{}, + }, + } + + for _, tt := range mandatoryFlags { + t.Run(strings.Join(tt.args, " "), func(t *testing.T) { + _, err := Setup(tt.args) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "mandatory") + }) + } +} + +func TestParseFlags(t *testing.T) { + var flags = []struct { + args []string + conf settings + }{ + { + []string{"-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path"}, + settings{ + GeodbPath: geodbPath{ + City: "/city-path", + ASN: "/asn-path", + }, + TemplatePath: "", + BindAddress: ":8080", + TLSAddress: "", + TLSCrtPath: "", + TLSKeyPath: "", + TrustedHeader: "", + Server: serverSettings{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, + }, + }, + { + []string{"-bind", ":8001", "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path"}, + settings{ + GeodbPath: geodbPath{ + City: "/city-path", + ASN: "/asn-path", + }, + TemplatePath: "", + BindAddress: ":8001", + TLSAddress: "", + TLSCrtPath: "", + TLSKeyPath: "", + TrustedHeader: "", + Server: serverSettings{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, + }, + }, + { + []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", "-tls-bind", ":9000", + "-tls-crt", "/crt-path", "-tls-key", "/key-path", + }, + settings{ + GeodbPath: geodbPath{ + City: "/city-path", + ASN: "/asn-path", + }, + TemplatePath: "", + BindAddress: ":8080", + TLSAddress: ":9000", + TLSCrtPath: "/crt-path", + TLSKeyPath: "/key-path", + TrustedHeader: "", + Server: serverSettings{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, + }, + }, + { + []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", + "-trusted-header", "header", + }, + settings{ + GeodbPath: geodbPath{ + City: "/city-path", + ASN: "/asn-path", + }, + TemplatePath: "", + BindAddress: ":8080", + TLSAddress: "", + TLSCrtPath: "", + TLSKeyPath: "", + TrustedHeader: "header", + Server: serverSettings{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, + }, + }, + } + + for _, tt := range flags { + t.Run(strings.Join(tt.args, " "), func(t *testing.T) { + _, err := Setup(tt.args) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(App, tt.conf)) + }) + } +} + +func TestParseFlagsUsage(t *testing.T) { + var usageArgs = []string{"-help", "-h", "--help"} + + for _, arg := range usageArgs { + t.Run(arg, func(t *testing.T) { + output, err := Setup([]string{arg}) + assert.ErrorIs(t, err, flag.ErrHelp) + assert.Contains(t, output, "Usage of") + }) + } +} + +func TestParseFlagVersion(t *testing.T) { + output, err := Setup([]string{"-version"}) + assert.ErrorIs(t, err, ErrVersion) + assert.Contains(t, output, "whatismyip version") +} + +func TestParseFlagTemplate(t *testing.T) { + flags := []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", + "-template", "/template-path", + } + _, err := Setup(flags) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no such file or directory") + + flags = []string{ + "-geoip2-city", "/city-path", "-geoip2-asn", "/asn-path", + "-template", "/", + } + _, err = Setup(flags) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be a file") +}