chore: Use black formatter (#9)

This commit is contained in:
Daniel Carrillo 2024-06-08 13:43:41 +02:00 committed by Daniel Carrillo
parent e5685511ba
commit a90dbd0123
Signed by: dcarrillo
GPG Key ID: E4CD5C09DAED6E16
8 changed files with 240 additions and 235 deletions

View File

@ -14,7 +14,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@ -43,15 +43,15 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python 3.11 - name: Set up Python 3.12
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: 3.11 python-version: 3.12
- name: Install tools - name: Install tools
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install twine wheel pip install twine wheel setuptools
- name: Build - name: Build
run: | run: |

View File

@ -1,2 +1,2 @@
__version__ = '1.0.6' __version__ = "1.0.7"
__description__ = 'Look up canonical information for AWS IP addresses and networks' __description__ = "Look up canonical information for AWS IP addresses and networks"

View File

@ -9,23 +9,21 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
from dateutil import tz
import requests import requests
from dateutil import tz
from requests.exceptions import RequestException from requests.exceptions import RequestException
from . import __description__ from . import __description__, __version__
from . import __version__
AWS_IP_RANGES_URL = 'https://ip-ranges.amazonaws.com/ip-ranges.json' AWS_IP_RANGES_URL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
CACHE_DIR = Path(Path.home() / '.digaws') CACHE_DIR = Path(Path.home() / ".digaws")
CACHE_FILE = CACHE_DIR / 'ip-ranges.json' CACHE_FILE = CACHE_DIR / "ip-ranges.json"
OUTPUT_FIELDS = ['prefix', 'region', 'service', 'network_border_group'] OUTPUT_FIELDS = ["prefix", "region", "service", "network_border_group"]
logger = logging.getLogger() logger = logging.getLogger()
handler = logging.StreamHandler(sys.stderr) handler = logging.StreamHandler(sys.stderr)
logger.addHandler(handler) logger.addHandler(handler)
handler.setFormatter(logging.Formatter('-- %(levelname)s -- %(message)s')) handler.setFormatter(logging.Formatter("-- %(levelname)s -- %(message)s"))
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -34,51 +32,49 @@ def get_aws_ip_ranges() -> Dict:
headers = {} headers = {}
try: try:
file_time = datetime.fromtimestamp( file_time = datetime.fromtimestamp(CACHE_FILE.stat().st_mtime, tz=tz.UTC).strftime(
CACHE_FILE.stat().st_mtime, "%a, %d %b %Y %H:%M:%S GMT"
tz=tz.UTC).strftime('%a, %d %b %Y %H:%M:%S GMT') )
logger.debug(f'cached file modification time: {file_time}') logger.debug(f"cached file modification time: {file_time}")
headers = {'If-Modified-Since': file_time} headers = {"If-Modified-Since": file_time}
except FileNotFoundError as e: except FileNotFoundError as e:
logger.debug(f'Not found: {CACHE_FILE}: {e}') logger.debug(f"Not found: {CACHE_FILE}: {e}")
pass pass
try: try:
response = requests.get( response = requests.get(url=AWS_IP_RANGES_URL, timeout=5, headers=headers)
url=AWS_IP_RANGES_URL,
timeout=5,
headers=headers
)
if response.status_code == 304: if response.status_code == 304:
try: try:
logger.debug(f'reading cached file {CACHE_FILE}') logger.debug(f"reading cached file {CACHE_FILE}")
with open(CACHE_FILE) as ip_ranges: with open(CACHE_FILE) as ip_ranges:
return json.load(ip_ranges) return json.load(ip_ranges)
except (OSError, IOError, json.JSONDecodeError) as e: except (OSError, IOError, json.JSONDecodeError) as e:
logger.debug(f'ERROR reading {CACHE_FILE}: {e}') logger.debug(f"ERROR reading {CACHE_FILE}: {e}")
raise CachedFileException(str(e)) raise CachedFileException(str(e))
elif response.status_code == 200: elif response.status_code == 200:
try: try:
with open(CACHE_FILE, 'w') as f: with open(CACHE_FILE, "w") as f:
f.write(response.text) f.write(response.text)
except (OSError, IOError) as e: except (OSError, IOError) as e:
logger.warning(e) logger.warning(e)
return response.json() return response.json()
else: else:
msg = f'Unexpected response from {AWS_IP_RANGES_URL}. Status code: ' \ msg = (
f'{response.status_code}. Content: {response.text}' f"Unexpected response from {AWS_IP_RANGES_URL}. Status code: "
f"{response.status_code}. Content: {response.text}"
)
logger.debug(msg) logger.debug(msg)
raise UnexpectedRequestException(msg) raise UnexpectedRequestException(msg)
except RequestException as e: except RequestException as e:
logger.debug(f'ERROR retrieving {AWS_IP_RANGES_URL}: {e}') logger.debug(f"ERROR retrieving {AWS_IP_RANGES_URL}: {e}")
raise e raise e
class CachedFileException(Exception): class CachedFileException(Exception):
def __init__(self, message: str): def __init__(self, message: str):
message = f'Error reading cached ranges {CACHE_FILE}: {message}' message = f"Error reading cached ranges {CACHE_FILE}: {message}"
super(CachedFileException, self).__init__(message) super(CachedFileException, self).__init__(message)
@ -94,141 +90,137 @@ class DigAWSPrettyPrinter:
def plain_print(self) -> None: def plain_print(self) -> None:
for prefix in self.data: for prefix in self.data:
if 'prefix' in self.output_fields: if "prefix" in self.output_fields:
try: try:
print(f'Prefix: {prefix["ip_prefix"]}') print(f'Prefix: {prefix["ip_prefix"]}')
except KeyError: except KeyError:
print(f'IPv6 Prefix: {prefix["ipv6_prefix"]}') print(f'IPv6 Prefix: {prefix["ipv6_prefix"]}')
if 'region' in self.output_fields: if "region" in self.output_fields:
print(f'Region: {prefix["region"]}') print(f'Region: {prefix["region"]}')
if 'service' in self.output_fields: if "service" in self.output_fields:
print(f'Service: {prefix["service"]}') print(f'Service: {prefix["service"]}')
if 'network_border_group' in self.output_fields: if "network_border_group" in self.output_fields:
print(f'Network border group: {prefix["network_border_group"]}') print(f'Network border group: {prefix["network_border_group"]}')
print('') print("")
def json_print(self) -> None: def json_print(self) -> None:
data = [] data = []
for prefix in self.data: for prefix in self.data:
try: try:
prefix['ip_prefix'] prefix["ip_prefix"]
prefix_type = 'ip_prefix' prefix_type = "ip_prefix"
except KeyError: except KeyError:
prefix_type = 'ipv6_prefix' prefix_type = "ipv6_prefix"
item_dict = {} item_dict = {}
if 'prefix' in self.output_fields: if "prefix" in self.output_fields:
item_dict.update({prefix_type: str(prefix[prefix_type])}) item_dict.update({prefix_type: str(prefix[prefix_type])})
if 'region' in self.output_fields: if "region" in self.output_fields:
item_dict.update({'region': prefix['region']}) item_dict.update({"region": prefix["region"]})
if 'service' in self.output_fields: if "service" in self.output_fields:
item_dict.update({'service': prefix['service']}) item_dict.update({"service": prefix["service"]})
if 'network_border_group' in self.output_fields: if "network_border_group" in self.output_fields:
item_dict.update({'network_border_group': prefix['network_border_group']}) item_dict.update({"network_border_group": prefix["network_border_group"]})
data.append(item_dict) data.append(item_dict)
print(json.dumps(data, indent=2)) print(json.dumps(data, indent=2))
class DigAWS(): class DigAWS:
def __init__(self, *, ip_ranges: Dict, output: str = 'plain', output_fields: List[str] = []): def __init__(self, *, ip_ranges: Dict, output: str = "plain", output_fields: List[str] = []):
self.output = output self.output = output
self.output_fields = output_fields self.output_fields = output_fields
self.ip_prefixes = [ self.ip_prefixes = [
{ {
'ip_prefix': ipaddress.IPv4Network(prefix['ip_prefix']), "ip_prefix": ipaddress.IPv4Network(prefix["ip_prefix"]),
'region': prefix['region'], "region": prefix["region"],
'service': prefix['service'], "service": prefix["service"],
'network_border_group': prefix['network_border_group'] "network_border_group": prefix["network_border_group"],
} }
for prefix in ip_ranges['prefixes'] for prefix in ip_ranges["prefixes"]
] ]
self.ipv6_prefixes = [ self.ipv6_prefixes = [
{ {
'ipv6_prefix': ipaddress.IPv6Network(prefix['ipv6_prefix']), "ipv6_prefix": ipaddress.IPv6Network(prefix["ipv6_prefix"]),
'region': prefix['region'], "region": prefix["region"],
'service': prefix['service'], "service": prefix["service"],
'network_border_group': prefix['network_border_group'] "network_border_group": prefix["network_border_group"],
} }
for prefix in ip_ranges['ipv6_prefixes'] for prefix in ip_ranges["ipv6_prefixes"]
] ]
def lookup(self, address: str) -> DigAWSPrettyPrinter: def lookup(self, address: str) -> DigAWSPrettyPrinter:
return DigAWSPrettyPrinter( return DigAWSPrettyPrinter(self._lookup_data(address), self.output_fields)
self._lookup_data(address),
self.output_fields
)
def _lookup_data(self, address: str) -> List[Dict]: def _lookup_data(self, address: str) -> List[Dict]:
addr: Any = None addr: Any = None
try: try:
addr = ipaddress.IPv4Address(address) addr = ipaddress.IPv4Address(address)
data = [prefix for prefix in self.ip_prefixes data = [prefix for prefix in self.ip_prefixes if addr in prefix["ip_prefix"]]
if addr in prefix['ip_prefix']]
except ipaddress.AddressValueError: except ipaddress.AddressValueError:
try: try:
addr = ipaddress.IPv6Address(address) addr = ipaddress.IPv6Address(address)
data = [prefix for prefix in self.ipv6_prefixes data = [prefix for prefix in self.ipv6_prefixes if addr in prefix["ipv6_prefix"]]
if addr in prefix['ipv6_prefix']]
except ipaddress.AddressValueError: except ipaddress.AddressValueError:
try: try:
addr = ipaddress.IPv4Network(address) addr = ipaddress.IPv4Network(address)
data = [prefix for prefix in self.ip_prefixes data = [
if addr.subnet_of(prefix['ip_prefix'])] prefix for prefix in self.ip_prefixes if addr.subnet_of(prefix["ip_prefix"])
]
except (ipaddress.AddressValueError, ValueError): except (ipaddress.AddressValueError, ValueError):
try: try:
addr = ipaddress.IPv6Network(address) addr = ipaddress.IPv6Network(address)
data = [prefix for prefix in self.ipv6_prefixes data = [
if addr.subnet_of(prefix['ipv6_prefix'])] prefix
for prefix in self.ipv6_prefixes
if addr.subnet_of(prefix["ipv6_prefix"])
]
except (ipaddress.AddressValueError, ValueError): except (ipaddress.AddressValueError, ValueError):
raise ValueError(f'Wrong IP or CIDR format: {address}') raise ValueError(f"Wrong IP or CIDR format: {address}")
return data return data
def arguments_parser() -> argparse.ArgumentParser: def arguments_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(add_help=True, description=__description__)
add_help=True,
description=__description__
)
parser.add_argument( parser.add_argument(
'--output', "--output",
metavar='<plain|json>', metavar="<plain|json>",
choices=['plain', 'json'], choices=["plain", "json"],
type=str, type=str,
required=False, required=False,
dest='output', dest="output",
default='plain', default="plain",
help='Formatting style for command output, by default %(default)s' help="Formatting style for command output, by default %(default)s",
) )
parser.add_argument( parser.add_argument(
'--output-fields', "--output-fields",
nargs='*', nargs="*",
choices=OUTPUT_FIELDS, choices=OUTPUT_FIELDS,
required=False, required=False,
dest='output_fields', dest="output_fields",
default=OUTPUT_FIELDS, default=OUTPUT_FIELDS,
help='Print only the given fields' help="Print only the given fields",
) )
parser.add_argument( parser.add_argument(
'--debug', "--debug",
action='store_true', action="store_true",
required=False, required=False,
default=False, default=False,
dest='debug', dest="debug",
help='Enable debug' help="Enable debug",
) )
parser.add_argument( parser.add_argument(
'--version', "--version",
action='version', action="version",
version='%(prog)s {version}'.format(version=__version__) version="%(prog)s {version}".format(version=__version__),
) )
parser.add_argument( parser.add_argument(
'addresses', "addresses",
nargs='+', nargs="+",
metavar='<ip address|cidr>', metavar="<ip address|cidr>",
type=str, type=str,
help='CIDR or IP (v4 or v6) to look up' help="CIDR or IP (v4 or v6) to look up",
) )
return parser return parser
@ -248,7 +240,7 @@ def main():
for address in args.addresses: for address in args.addresses:
responses.append(dig.lookup(address)) responses.append(dig.lookup(address))
if args.output == 'plain': if args.output == "plain":
for response in responses: for response in responses:
response.plain_print() response.plain_print()
else: else:
@ -265,8 +257,9 @@ def main():
ipaddress.AddressValueError, ipaddress.AddressValueError,
ValueError, ValueError,
CachedFileException, CachedFileException,
UnexpectedRequestException) as e: UnexpectedRequestException,
print(f'ERROR: {e}') ) as e:
print(f"ERROR: {e}")
sys.exit(1) sys.exit(1)

View File

@ -1,34 +1,35 @@
import nox import nox
nox.options.sessions = ["lint", "typing", "tests"]
locations = ["noxfile.py", "setup.py", "digaws/", "tests/"]
nox.options.sessions = ['lint', 'typing', 'tests'] lint_common_args = ["--max-line-length", "100"]
locations = ['noxfile.py', 'setup.py', 'digaws/', 'tests/'] black_args = ["--line-length", "100"]
mypy_args = ["--ignore-missing-imports", "--install-types", "--non-interactive"]
lint_common_args = ['--max-line-length', '120'] pytest_args = ["--cov=digaws", "--cov-report=", "tests/"]
mypy_args = ['--ignore-missing-imports', '--install-types', '--non-interactive'] coverage_args = ["report", "--show-missing", "--fail-under=80"]
pytest_args = ['--cov=digaws', '--cov-report=', 'tests/']
coverage_args = ['report', '--show-missing', '--fail-under=80']
@nox.session() @nox.session()
def lint(session): def lint(session):
args = session.posargs or locations args = session.posargs or locations
session.install('pycodestyle', 'flake8', 'flake8-import-order') session.install("pycodestyle", "flake8", "black")
session.run('pycodestyle', *(lint_common_args + args)) session.run("pycodestyle", *(lint_common_args + args))
session.run('flake8', *(lint_common_args + args)) session.run("flake8", *(lint_common_args + args))
session.run("black", "--check", *(black_args + args))
@nox.session() @nox.session()
def typing(session): def typing(session):
args = session.posargs or locations args = session.posargs or locations
session.install('mypy') session.install("mypy")
session.run('mypy', *(mypy_args + args)) session.run("mypy", *(mypy_args + args))
@nox.session() @nox.session()
def tests(session): def tests(session):
args = session.posargs args = session.posargs
session.install('-r', 'requirements_test.txt') session.install("-r", "requirements_test.txt")
session.run('pytest', *(pytest_args + args)) session.run("pytest", *(pytest_args + args))
session.run('coverage', *coverage_args) session.run("coverage", *coverage_args)

View File

@ -1,36 +1,34 @@
from digaws import __description__, __version__
from setuptools import setup from setuptools import setup
from digaws import __description__, __version__
def get_long_description() -> str: def get_long_description() -> str:
with open('README.md', 'r', encoding='utf-8') as fh: with open("README.md", "r", encoding="utf-8") as fh:
return fh.read() return fh.read()
setup( setup(
name='digaws', name="digaws",
version=__version__, version=__version__,
description=__description__, description=__description__,
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
url='http://github.com/dcarrillo/digaws', url="http://github.com/dcarrillo/digaws",
author='Daniel Carrillo', author="Daniel Carrillo",
author_email='daniel.carrillo@gmail.com', author_email="daniel.carrillo@gmail.com",
license='Apache Software License', license="Apache Software License",
packages=['digaws'], packages=["digaws"],
zip_safe=False, zip_safe=False,
classifiers=[ classifiers=[
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'License :: OSI Approved :: Apache Software License', "License :: OSI Approved :: Apache Software License",
'Operating System :: OS Independent', "Operating System :: OS Independent",
], ],
python_requires='>=3.8', python_requires=">=3.8",
entry_points={ entry_points={"console_scripts": ["digaws=digaws.digaws:main"]},
'console_scripts': ['digaws=digaws.digaws:main']
},
install_requires=[ install_requires=[
'python-dateutil~=2.8', "python-dateutil~=2.8",
'requests~=2.25', "requests~=2.25",
] ],
) )

View File

@ -1,6 +1,6 @@
import ipaddress import ipaddress
AWS_IP_RANGES = ''' AWS_IP_RANGES = """
{ {
"syncToken": "1608245058", "syncToken": "1608245058",
"createDate": "2020-12-17-22-44-18", "createDate": "2020-12-17-22-44-18",
@ -45,57 +45,57 @@ AWS_IP_RANGES = '''
} }
] ]
} }
''' """
AWS_IPV4_RANGES_OBJ = [ AWS_IPV4_RANGES_OBJ = [
{ {
'ip_prefix': ipaddress.IPv4Network('52.93.178.234/32'), "ip_prefix": ipaddress.IPv4Network("52.93.178.234/32"),
'region': 'us-west-1', "region": "us-west-1",
'service': 'AMAZON', "service": "AMAZON",
'network_border_group': 'us-west-1' "network_border_group": "us-west-1",
}, },
{ {
'ip_prefix': ipaddress.IPv4Network('52.94.76.0/22'), "ip_prefix": ipaddress.IPv4Network("52.94.76.0/22"),
'region': 'us-west-2', "region": "us-west-2",
'service': 'AMAZON', "service": "AMAZON",
'network_border_group': 'us-west-2' "network_border_group": "us-west-2",
} },
] ]
AWS_IPV6_RANGES_OBJ = [ AWS_IPV6_RANGES_OBJ = [
{ {
'ipv6_prefix': ipaddress.IPv6Network('2600:1f00:c000::/40'), "ipv6_prefix": ipaddress.IPv6Network("2600:1f00:c000::/40"),
'region': 'us-west-1', "region": "us-west-1",
'service': 'AMAZON', "service": "AMAZON",
'network_border_group': 'us-west-1' "network_border_group": "us-west-1",
}, },
{ {
'ipv6_prefix': ipaddress.IPv6Network('2600:1f01:4874::/47'), "ipv6_prefix": ipaddress.IPv6Network("2600:1f01:4874::/47"),
'region': 'us-west-2', "region": "us-west-2",
'service': 'AMAZON', "service": "AMAZON",
'network_border_group': 'us-west-2' "network_border_group": "us-west-2",
}, },
{ {
'ipv6_prefix': ipaddress.IPv6Network('2600:1f14:fff:f800::/53'), "ipv6_prefix": ipaddress.IPv6Network("2600:1f14:fff:f800::/53"),
'region': 'us-west-2', "region": "us-west-2",
'service': 'ROUTE53_HEALTHCHECKS', "service": "ROUTE53_HEALTHCHECKS",
'network_border_group': 'us-west-2' "network_border_group": "us-west-2",
}, },
{ {
'ipv6_prefix': ipaddress.IPv6Network('2600:1f14::/35'), "ipv6_prefix": ipaddress.IPv6Network("2600:1f14::/35"),
'region': 'us-west-2', "region": "us-west-2",
'service': 'EC2', "service": "EC2",
'network_border_group': 'us-west-2' "network_border_group": "us-west-2",
} },
] ]
LAST_MODIFIED_TIME = 'Thu, 17 Dec 2020 23:22:33 GMT' LAST_MODIFIED_TIME = "Thu, 17 Dec 2020 23:22:33 GMT"
RESPONSE_PLAIN_PRINT = '''Prefix: 52.94.76.0/22 RESPONSE_PLAIN_PRINT = """Prefix: 52.94.76.0/22
Region: us-west-2 Region: us-west-2
Service: AMAZON Service: AMAZON
Network border group: us-west-2 Network border group: us-west-2
''' """
RESPONSE_JSON_PRINT = '''[ RESPONSE_JSON_PRINT = """[
{ {
"ipv6_prefix": "2600:1f14:fff:f800::/53", "ipv6_prefix": "2600:1f14:fff:f800::/53",
"region": "us-west-2", "region": "us-west-2",
@ -109,9 +109,9 @@ RESPONSE_JSON_PRINT = '''[
"network_border_group": "us-west-2" "network_border_group": "us-west-2"
} }
] ]
''' """
RESPONSE_JSON_FIELDS_PRINT = '''[ RESPONSE_JSON_FIELDS_PRINT = """[
{ {
"service": "ROUTE53_HEALTHCHECKS", "service": "ROUTE53_HEALTHCHECKS",
"network_border_group": "us-west-2" "network_border_group": "us-west-2"
@ -121,9 +121,9 @@ RESPONSE_JSON_FIELDS_PRINT = '''[
"network_border_group": "us-west-2" "network_border_group": "us-west-2"
} }
] ]
''' """
RESPONSE_JSON_JOINED_PRINT = '''[ RESPONSE_JSON_JOINED_PRINT = """[
{ {
"ip_prefix": "52.94.76.0/22", "ip_prefix": "52.94.76.0/22",
"region": "us-west-2", "region": "us-west-2",
@ -143,4 +143,4 @@ RESPONSE_JSON_JOINED_PRINT = '''[
"network_border_group": "us-west-2" "network_border_group": "us-west-2"
} }
] ]
''' """

View File

@ -1,24 +1,22 @@
import json import json
import sys import sys
import digaws.digaws as digaws
from digaws import __description__, __version__
import pytest import pytest
import digaws.digaws as digaws
import tests import tests
from digaws import __description__, __version__
@pytest.fixture @pytest.fixture
def test_dig(): def test_dig():
return digaws.DigAWS(ip_ranges=json.loads( return digaws.DigAWS(
tests.AWS_IP_RANGES), ip_ranges=json.loads(tests.AWS_IP_RANGES), output_fields=digaws.OUTPUT_FIELDS
output_fields=digaws.OUTPUT_FIELDS
) )
def test_cli(capsys): def test_cli(capsys):
sys.argv = ['digaws', '-h'] sys.argv = ["digaws", "-h"]
try: try:
digaws.main() digaws.main()
except SystemExit as e: except SystemExit as e:
@ -28,19 +26,24 @@ def test_cli(capsys):
def test_cli_version(capsys, mocker): def test_cli_version(capsys, mocker):
sys.argv = ['digaws', '--version'] sys.argv = ["digaws", "--version"]
try: try:
digaws.main() digaws.main()
except SystemExit as e: except SystemExit as e:
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
assert out == f'digaws {__version__}\n' assert out == f"digaws {__version__}\n"
assert e.code == 0 assert e.code == 0
def test_cli_invocation(capsys, mocker): def test_cli_invocation(capsys, mocker):
sys.argv = ['digaws', '52.94.76.0/22', '2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8', sys.argv = [
'--output', 'json'] "digaws",
mocker.patch('digaws.digaws.get_aws_ip_ranges', return_value=json.loads(tests.AWS_IP_RANGES)) "52.94.76.0/22",
"2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8",
"--output",
"json",
]
mocker.patch("digaws.digaws.get_aws_ip_ranges", return_value=json.loads(tests.AWS_IP_RANGES))
digaws.main() digaws.main()
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
@ -48,18 +51,30 @@ def test_cli_invocation(capsys, mocker):
def test_cli_output_plain_fields_invocation(capsys, mocker): def test_cli_output_plain_fields_invocation(capsys, mocker):
sys.argv = ['digaws', '52.94.76.0/22', '--output=plain', '--output-fields', 'region'] sys.argv = [
mocker.patch('digaws.digaws.get_aws_ip_ranges', return_value=json.loads(tests.AWS_IP_RANGES)) "digaws",
"52.94.76.0/22",
"--output=plain",
"--output-fields",
"region",
]
mocker.patch("digaws.digaws.get_aws_ip_ranges", return_value=json.loads(tests.AWS_IP_RANGES))
digaws.main() digaws.main()
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
assert out == 'Region: us-west-2\n\n' assert out == "Region: us-west-2\n\n"
def test_cli_output_json_fields_invocation(capsys, mocker): def test_cli_output_json_fields_invocation(capsys, mocker):
sys.argv = ['digaws', '2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8', '--output=json', sys.argv = [
'--output-fields', 'service', 'network_border_group'] "digaws",
mocker.patch('digaws.digaws.get_aws_ip_ranges', return_value=json.loads(tests.AWS_IP_RANGES)) "2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8",
"--output=json",
"--output-fields",
"service",
"network_border_group",
]
mocker.patch("digaws.digaws.get_aws_ip_ranges", return_value=json.loads(tests.AWS_IP_RANGES))
digaws.main() digaws.main()
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
@ -72,28 +87,28 @@ def test_dig_aws_construct(test_dig):
def test_lookup(test_dig): def test_lookup(test_dig):
assert str(test_dig._lookup_data('52.94.76.1')[0]['ip_prefix']) == '52.94.76.0/22' assert str(test_dig._lookup_data("52.94.76.1")[0]["ip_prefix"]) == "52.94.76.0/22"
assert str(test_dig._lookup_data('52.94.76.0/24')[0]['ip_prefix']) == '52.94.76.0/22' assert str(test_dig._lookup_data("52.94.76.0/24")[0]["ip_prefix"]) == "52.94.76.0/22"
input = '2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8' input = "2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8"
assert str(test_dig._lookup_data(input)[0]['ipv6_prefix']) == '2600:1f14:fff:f800::/53' assert str(test_dig._lookup_data(input)[0]["ipv6_prefix"]) == "2600:1f14:fff:f800::/53"
assert str(test_dig._lookup_data(input)[1]['ipv6_prefix']) == '2600:1f14::/35' assert str(test_dig._lookup_data(input)[1]["ipv6_prefix"]) == "2600:1f14::/35"
assert str(test_dig._lookup_data('2600:1f14::/36')[0]['ipv6_prefix']) == '2600:1f14::/35' assert str(test_dig._lookup_data("2600:1f14::/36")[0]["ipv6_prefix"]) == "2600:1f14::/35"
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
test_dig.lookup('what are you talking about') test_dig.lookup("what are you talking about")
assert e.startswith('Wrong IP or CIDR format') assert e.startswith("Wrong IP or CIDR format")
def test_response_plain_print(test_dig, capsys): def test_response_plain_print(test_dig, capsys):
test_dig.lookup('52.94.76.0/22').plain_print() test_dig.lookup("52.94.76.0/22").plain_print()
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
assert out == tests.RESPONSE_PLAIN_PRINT assert out == tests.RESPONSE_PLAIN_PRINT
def test_response_json_print(test_dig, capsys): def test_response_json_print(test_dig, capsys):
test_dig.lookup('2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8').json_print() test_dig.lookup("2600:1f14:fff:f810:a1c1:f507:a2d1:2dd8").json_print()
out, _ = capsys.readouterr() out, _ = capsys.readouterr()
assert out == tests.RESPONSE_JSON_PRINT assert out == tests.RESPONSE_JSON_PRINT

View File

@ -1,12 +1,10 @@
import json import json
import os import os
import digaws.digaws as digaws
import pytest import pytest
import requests import requests
import digaws.digaws as digaws
import tests import tests
@ -28,61 +26,61 @@ def create_cache_dir(fs):
digaws.CACHE_DIR.mkdir(parents=True) digaws.CACHE_DIR.mkdir(parents=True)
@pytest.mark.parametrize('fs', [[None, [digaws]]], indirect=True) @pytest.mark.parametrize("fs", [[None, [digaws]]], indirect=True)
def test_get_aws_ip_ranges_cached_valid_file(mocker, fs, create_cache_dir) -> None: def test_get_aws_ip_ranges_cached_valid_file(mocker, fs, create_cache_dir) -> None:
with open(digaws.CACHE_FILE, 'w') as out: with open(digaws.CACHE_FILE, "w") as out:
out.write(tests.AWS_IP_RANGES) out.write(tests.AWS_IP_RANGES)
response = requests.Response response = requests.Response
response.status_code = 304 response.status_code = 304
mocker.patch('requests.get', return_value=response) mocker.patch("requests.get", return_value=response)
result = digaws.get_aws_ip_ranges() result = digaws.get_aws_ip_ranges()
assert result['syncToken'] == '1608245058' assert result["syncToken"] == "1608245058"
@pytest.mark.parametrize('fs', [[None, [digaws]]], indirect=True) @pytest.mark.parametrize("fs", [[None, [digaws]]], indirect=True)
def test_get_aws_ip_ranges_cached_invalid_file(mocker, fs, create_cache_dir) -> None: def test_get_aws_ip_ranges_cached_invalid_file(mocker, fs, create_cache_dir) -> None:
with open(digaws.CACHE_FILE, 'w'): with open(digaws.CACHE_FILE, "w"):
pass pass
response = requests.Response response = requests.Response
response.status_code = 304 response.status_code = 304
mocker.patch('requests.get', return_value=response) mocker.patch("requests.get", return_value=response)
with pytest.raises(digaws.CachedFileException): with pytest.raises(digaws.CachedFileException):
digaws.get_aws_ip_ranges() digaws.get_aws_ip_ranges()
@pytest.mark.parametrize('fs', [[None, [digaws]]], indirect=True) @pytest.mark.parametrize("fs", [[None, [digaws]]], indirect=True)
def test_get_aws_ip_ranges_cached_deprecated_file(monkeypatch, fs, create_cache_dir) -> None: def test_get_aws_ip_ranges_cached_deprecated_file(monkeypatch, fs, create_cache_dir) -> None:
with open(digaws.CACHE_FILE, 'w'): with open(digaws.CACHE_FILE, "w"):
pass pass
digaws.CACHE_FILE.touch() digaws.CACHE_FILE.touch()
os.utime(digaws.CACHE_FILE, times=(0, 0)) os.utime(digaws.CACHE_FILE, times=(0, 0))
monkeypatch.setattr(requests, 'get', mock_get) monkeypatch.setattr(requests, "get", mock_get)
result = digaws.get_aws_ip_ranges() result = digaws.get_aws_ip_ranges()
assert result['syncToken'] == '1608245058' assert result["syncToken"] == "1608245058"
@pytest.mark.parametrize('fs', [[None, [digaws]]], indirect=True) @pytest.mark.parametrize("fs", [[None, [digaws]]], indirect=True)
def test_get_aws_ip_ranges_no_file(monkeypatch, fs, create_cache_dir) -> None: def test_get_aws_ip_ranges_no_file(monkeypatch, fs, create_cache_dir) -> None:
monkeypatch.setattr(requests, 'get', mock_get) monkeypatch.setattr(requests, "get", mock_get)
result = digaws.get_aws_ip_ranges() result = digaws.get_aws_ip_ranges()
assert result['syncToken'] == '1608245058' assert result["syncToken"] == "1608245058"
@pytest.mark.parametrize('fs', [[None, [digaws]]], indirect=True) @pytest.mark.parametrize("fs", [[None, [digaws]]], indirect=True)
def test_get_aws_ip_ranges_invalid_status(mocker, fs, create_cache_dir) -> None: def test_get_aws_ip_ranges_invalid_status(mocker, fs, create_cache_dir) -> None:
response = requests.Response response = requests.Response
response.status_code = 301 response.status_code = 301
mocker.patch('requests.get', return_value=response) mocker.patch("requests.get", return_value=response)
with pytest.raises(digaws.UnexpectedRequestException) as e: with pytest.raises(digaws.UnexpectedRequestException) as e:
digaws.get_aws_ip_ranges() digaws.get_aws_ip_ranges()
assert e.match('^Unexpected response from') assert e.match("^Unexpected response from")