Address Q000 in oidc-exchange.py

This commit is contained in:
Sviatoslav Sydorenko (Святослав Сидоренко) 2024-05-16 17:30:39 +02:00 committed by GitHub
parent 5569480d08
commit 9da6dedb16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -109,26 +109,26 @@ a few minutes and try again.
def die(msg: str) -> NoReturn: def die(msg: str) -> NoReturn:
with _GITHUB_STEP_SUMMARY.open("a", encoding="utf-8") as io: with _GITHUB_STEP_SUMMARY.open('a', encoding='utf-8') as io:
print(_ERROR_SUMMARY_MESSAGE.format(message=msg), file=io) print(_ERROR_SUMMARY_MESSAGE.format(message=msg), file=io)
# HACK: GitHub Actions' annotations don't work across multiple lines naively; # HACK: GitHub Actions' annotations don't work across multiple lines naively;
# translating `\n` into `%0A` (i.e., HTML percent-encoding) is known to work. # translating `\n` into `%0A` (i.e., HTML percent-encoding) is known to work.
# See: https://github.com/actions/toolkit/issues/193 # See: https://github.com/actions/toolkit/issues/193
msg = msg.replace("\n", "%0A") msg = msg.replace('\n', '%0A')
print(f"::error::Trusted publishing exchange failure: {msg}", file=sys.stderr) print(f'::error::Trusted publishing exchange failure: {msg}', file=sys.stderr)
sys.exit(1) sys.exit(1)
def debug(msg: str): def debug(msg: str):
print(f"::debug::{msg.title()}", file=sys.stderr) print(f'::debug::{msg.title()}', file=sys.stderr)
def get_normalized_input(name: str) -> str | None: def get_normalized_input(name: str) -> str | None:
name = f"INPUT_{name.upper()}" name = f'INPUT_{name.upper()}'
if val := os.getenv(name): if val := os.getenv(name):
return val return val
return os.getenv(name.replace("-", "_")) return os.getenv(name.replace('-', '_'))
def assert_successful_audience_call(resp: requests.Response, domain: str): def assert_successful_audience_call(resp: requests.Response, domain: str):
@ -140,13 +140,13 @@ def assert_successful_audience_call(resp: requests.Response, domain: str):
# This index supports OIDC, but forbids the client from using # This index supports OIDC, but forbids the client from using
# it (either because it's disabled, ratelimited, etc.) # it (either because it's disabled, ratelimited, etc.)
die( die(
f"audience retrieval failed: repository at {domain} has trusted publishing disabled", f'audience retrieval failed: repository at {domain} has trusted publishing disabled',
) )
case HTTPStatus.NOT_FOUND: case HTTPStatus.NOT_FOUND:
# This index does not support OIDC. # This index does not support OIDC.
die( die(
"audience retrieval failed: repository at " 'audience retrieval failed: repository at '
f"{domain} does not indicate trusted publishing support", f'{domain} does not indicate trusted publishing support',
) )
case other: case other:
status = HTTPStatus(other) status = HTTPStatus(other)
@ -154,67 +154,67 @@ def assert_successful_audience_call(resp: requests.Response, domain: str):
# something we expect. This can happen if the index is broken, in maintenance mode, # something we expect. This can happen if the index is broken, in maintenance mode,
# misconfigured, etc. # misconfigured, etc.
die( die(
"audience retrieval failed: repository at " 'audience retrieval failed: repository at '
f"{domain} responded with unexpected {other}: {status.phrase}", f'{domain} responded with unexpected {other}: {status.phrase}',
) )
def render_claims(token: str) -> str: def render_claims(token: str) -> str:
_, payload, _ = token.split(".", 2) _, payload, _ = token.split('.', 2)
# urlsafe_b64decode needs padding; JWT payloads don't contain any. # urlsafe_b64decode needs padding; JWT payloads don't contain any.
payload += "=" * (4 - (len(payload) % 4)) payload += "=" * (4 - (len(payload) % 4))
claims = json.loads(base64.urlsafe_b64decode(payload)) claims = json.loads(base64.urlsafe_b64decode(payload))
def _get(name: str) -> str: # noqa: WPS430 def _get(name: str) -> str: # noqa: WPS430
return claims.get(name, "MISSING") return claims.get(name, 'MISSING')
return _RENDERED_CLAIMS.format( return _RENDERED_CLAIMS.format(
sub=_get("sub"), sub=_get('sub'),
repository=_get("repository"), repository=_get('repository'),
repository_owner=_get("repository_owner"), repository_owner=_get('repository_owner'),
repository_owner_id=_get("repository_owner_id"), repository_owner_id=_get('repository_owner_id'),
job_workflow_ref=_get("job_workflow_ref"), job_workflow_ref=_get('job_workflow_ref'),
ref=_get("ref"), ref=_get('ref'),
) )
def event_is_third_party_pr() -> bool: def event_is_third_party_pr() -> bool:
# Non-`pull_request` events cannot be from third-party PRs. # Non-`pull_request` events cannot be from third-party PRs.
if os.getenv("GITHUB_EVENT_NAME") != "pull_request": if os.getenv('GITHUB_EVENT_NAME') != 'pull_request':
return False return False
event_path = os.getenv("GITHUB_EVENT_PATH") event_path = os.getenv('GITHUB_EVENT_PATH')
if not event_path: if not event_path:
# No GITHUB_EVENT_PATH indicates a weird GitHub or runner bug. # No GITHUB_EVENT_PATH indicates a weird GitHub or runner bug.
debug("unexpected: no GITHUB_EVENT_PATH to check") debug('unexpected: no GITHUB_EVENT_PATH to check')
return False return False
try: try:
event = json.loads(Path(event_path).read_bytes()) event = json.loads(Path(event_path).read_bytes())
except json.JSONDecodeError: except json.JSONDecodeError:
debug("unexpected: GITHUB_EVENT_PATH does not contain valid JSON") debug('unexpected: GITHUB_EVENT_PATH does not contain valid JSON')
return False return False
try: try:
return event["pull_request"]["head"]["repo"]["fork"] return event['pull_request']['head']['repo']['fork']
except KeyError: except KeyError:
return False return False
repository_url = get_normalized_input("repository-url") repository_url = get_normalized_input('repository-url')
repository_domain = urlparse(repository_url).netloc repository_domain = urlparse(repository_url).netloc
token_exchange_url = f"https://{repository_domain}/_/oidc/mint-token" token_exchange_url = f'https://{repository_domain}/_/oidc/mint-token'
# Indices are expected to support `https://{domain}/_/oidc/audience`, # Indices are expected to support `https://{domain}/_/oidc/audience`,
# which tells OIDC exchange clients which audience to use. # which tells OIDC exchange clients which audience to use.
audience_url = f"https://{repository_domain}/_/oidc/audience" audience_url = f'https://{repository_domain}/_/oidc/audience'
audience_resp = requests.get(audience_url) audience_resp = requests.get(audience_url)
assert_successful_audience_call(audience_resp, repository_domain) assert_successful_audience_call(audience_resp, repository_domain)
oidc_audience = audience_resp.json()["audience"] oidc_audience = audience_resp.json()['audience']
debug(f"selected trusted publishing exchange endpoint: {token_exchange_url}") debug(f'selected trusted publishing exchange endpoint: {token_exchange_url}')
try: try:
oidc_token = id.detect_credential(audience=oidc_audience) oidc_token = id.detect_credential(audience=oidc_audience)
@ -229,7 +229,7 @@ except id.IdentityError as identity_error:
# Now we can do the actual token exchange. # Now we can do the actual token exchange.
mint_token_resp = requests.post( mint_token_resp = requests.post(
token_exchange_url, token_exchange_url,
json={"token": oidc_token}, json={'token': oidc_token},
) )
try: try:
@ -246,9 +246,9 @@ except requests.JSONDecodeError:
# On failure, the JSON response includes the list of errors that # On failure, the JSON response includes the list of errors that
# occurred during minting. # occurred during minting.
if not mint_token_resp.ok: if not mint_token_resp.ok:
reasons = "\n".join( reasons = '\n'.join(
f"* `{error['code']}`: {error['description']}" f'* `{error['code']}`: {error['description']}'
for error in mint_token_payload["errors"] for error in mint_token_payload['errors']
) )
rendered_claims = render_claims(oidc_token) rendered_claims = render_claims(oidc_token)
@ -260,12 +260,12 @@ if not mint_token_resp.ok:
), ),
) )
pypi_token = mint_token_payload.get("token") pypi_token = mint_token_payload.get('token')
if pypi_token is None: if pypi_token is None:
die(_SERVER_TOKEN_RESPONSE_MALFORMED_MESSAGE) die(_SERVER_TOKEN_RESPONSE_MALFORMED_MESSAGE)
# Mask the newly minted PyPI token, so that we don't accidentally leak it in logs. # Mask the newly minted PyPI token, so that we don't accidentally leak it in logs.
print(f"::add-mask::{pypi_token}", file=sys.stderr) print(f'::add-mask::{pypi_token}', file=sys.stderr)
# This final print will be captured by the subshell in `twine-upload.sh`. # This final print will be captured by the subshell in `twine-upload.sh`.
print(pypi_token) print(pypi_token)