Skip to content

Commit

Permalink
feat: added method to show correct RP organization_name in OP pages (#…
Browse files Browse the repository at this point in the history
…305)

* feat: added get_client_organisation_name method to retrieve the correct RP name

* chore: fix CIE organization_name

* fix: updated cryptography rsa import to 42.0.2

* chore: bump to 1.3.1

* fix: corrected proposed change

* fix: scope issue

* Update spid_cie_oidc/provider/views/consent_page_view.py

Co-authored-by: Giuseppe De Marco <[email protected]>

* Update spid_cie_oidc/provider/views/__init__.py

Co-authored-by: Giuseppe De Marco <[email protected]>

* Update spid_cie_oidc/provider/views/authz_request_view.py

Co-authored-by: Giuseppe De Marco <[email protected]>

* fix: reinstated method name

---------

Co-authored-by: Giuseppe De Marco <[email protected]>
  • Loading branch information
rglauco and Giuseppe De Marco authored Feb 7, 2024
1 parent 1faa95e commit 5355473
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/provider/dumps/example.json
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
"metadata": {
"federation_entity": {
"federation_resolve_endpoint": "http://127.0.0.1:8002/oidc/op/resolve",
"organization_name": "SPID OIDC identity provider",
"organization_name": "CIE OIDC identity provider",
"homepage_uri": "http://127.0.0.1:8002",
"policy_uri": "http://127.0.0.1:8002/oidc/op/en/website/legal-information",
"logo_uri": "http://127.0.0.1:8002/static/svg/logo-cie.svg",
Expand Down
2 changes: 1 addition & 1 deletion spid_cie_oidc/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.3.0"
__version__ = "1.3.1"
6 changes: 3 additions & 3 deletions spid_cie_oidc/entity/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cryptojwt.jwk.rsa import new_rsa_key
from cryptography.hazmat.primitives import serialization
from cryptojwt.jwk.rsa import RSAKey

from cryptography.hazmat.primitives.asymmetric import rsa

import cryptography
from django.conf import settings
Expand Down Expand Up @@ -64,9 +64,9 @@ def serialize_rsa_key(rsa_key, kind="public", hash_func="SHA-256"):
cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey
"""
data = {}
if isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPublicKey):
if isinstance(rsa_key, rsa.RSAPublicKey):
data = {"pub_key": rsa_key}
elif isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey):
elif isinstance(rsa_key, rsa.RSAPrivateKey):
data = {"priv_key": rsa_key}
elif isinstance(rsa_key, (str, bytes)): # pragma: no cover
if kind == "private":
Expand Down
80 changes: 49 additions & 31 deletions spid_cie_oidc/provider/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH,
OIDCFED_PROVIDER_PROFILES_ID_TOKEN_CLAIMS
)

logger = logging.getLogger(__name__)


Expand All @@ -40,7 +41,7 @@ class OpBase:
Baseclass with common methods for OPs
"""

def redirect_response_data(self, redirect_uri:str, **kwargs) -> HttpResponseRedirect:
def redirect_response_data(self, redirect_uri: str, **kwargs) -> HttpResponseRedirect:
if "?" in redirect_uri:
qstring = "&"
else:
Expand Down Expand Up @@ -114,7 +115,7 @@ def validate_authz_request_object(self, req) -> TrustChain:

jwks = get_jwks(
rp_trust_chain.metadata['openid_relying_party'],
federation_jwks = rp_trust_chain.jwks
federation_jwks=rp_trust_chain.jwks
)
jwk = self.find_jwk(header, jwks)
if not jwk:
Expand Down Expand Up @@ -178,7 +179,7 @@ def check_session(self, request) -> OidcSession:
)

session_not_after = session.created + timezone.timedelta(
minutes = OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
minutes=OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
)
if session_not_after < timezone.localtime():
raise ExpiredAuthCode(
Expand All @@ -199,12 +200,12 @@ def check_client_assertion(self, client_id: str, client_assertion: str) -> bool:
_op = self.get_issuer()
_op_eid = _op.sub
_op_eid_authz_endpoint = [_op.metadata['openid_provider']['authorization_endpoint']]

try:
ClientAssertion(**payload)
except Exception as e:
raise Exception(f"Client Assertion: json schema validation error: {e}")

if isinstance(_aud, str):
_aud = [_aud]
_allowed_auds = _aud + _op_eid_authz_endpoint
Expand Down Expand Up @@ -250,9 +251,9 @@ def get_jwt_common_data(self):
}

def get_access_token(
self, iss_sub:str, sub:str, authz: OidcSession, commons:dict
self, iss_sub: str, sub: str, authz: OidcSession, commons: dict
) -> dict:

access_token = {
"iss": iss_sub,
"sub": sub,
Expand All @@ -266,8 +267,8 @@ def get_access_token(
return access_token

def get_id_token_claims(
self,
authz:OidcSession
self,
authz: OidcSession
) -> dict:
_provider_profile = getattr(settings, 'OIDCFED_DEFAULT_PROVIDER_PROFILE', OIDCFED_DEFAULT_PROVIDER_PROFILE)
claims = {}
Expand All @@ -276,21 +277,21 @@ def get_id_token_claims(
return claims

for claim in (
authz.authz_request.get(
"claims", {}
).get("id_token", {}).keys()
authz.authz_request.get(
"claims", {}
).get("id_token", {}).keys()
):
if claim in allowed_id_token_claims and authz.user.attributes.get(claim, None):
claims[claim] = authz.user.attributes[claim]
return claims

def get_id_token(
self,
iss_sub:str,
sub:str,
authz:OidcSession,
jwt_at:str,
commons:dict
self,
iss_sub: str,
sub: str,
authz: OidcSession,
jwt_at: str,
commons: dict
) -> dict:

id_token = {
Expand All @@ -312,19 +313,19 @@ def get_id_token(

def get_refresh_token(
self,
iss_sub:str,
sub:str,
authz:OidcSession,
jwt_at:str,
commons:dict
iss_sub: str,
sub: str,
authz: OidcSession,
jwt_at: str,
commons: dict
) -> dict:
# refresh token is scope offline_access and prompt == consent
refresh_acrs = OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH[OIDCFED_DEFAULT_PROVIDER_PROFILE]
acrs = authz.authz_request.get('acr_values', [])
if (
"offline_access" in authz.authz_request['scope'] and
'consent' in authz.authz_request['prompt'] and
set(refresh_acrs).intersection(set(acrs))
"offline_access" in authz.authz_request['scope'] and
'consent' in authz.authz_request['prompt'] and
set(refresh_acrs).intersection(set(acrs))
):
refresh_token = {
"sub": sub,
Expand All @@ -337,8 +338,8 @@ def get_refresh_token(
refresh_token.update(commons)
return refresh_token

def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConfiguration):
_sub = session.pairwised_sub(provider_id = issuer.sub)
def get_iss_token_data(self, session: OidcSession, issuer: FederationEntityConfiguration):
_sub = session.pairwised_sub(provider_id=issuer.sub)
iss_sub = issuer.sub
commons = self.get_jwt_common_data()
jwk = issuer.jwks_core[0]
Expand All @@ -363,7 +364,7 @@ def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConf

def get_expires_in(self, iat: int, exp: int):
return timezone.timedelta(
seconds = exp - iat
seconds=exp - iat
).seconds

def attributes_names_to_release(self, request, session: OidcSession) -> dict:
Expand Down Expand Up @@ -391,6 +392,23 @@ def attributes_names_to_release(self, request, session: OidcSession) -> dict:
for i in filtered_user_claims.keys()
]
return dict(
i18n_user_claims = i18n_user_claims,
filtered_user_claims = filtered_user_claims
i18n_user_claims=i18n_user_claims,
filtered_user_claims=filtered_user_claims
)

def get_client_organization_name(self, tc):
rp_metadata = (
tc.metadata.get(
"federation_entity", {}
) or
tc.metadata.get(
"openid_relying_party", {}
)
)
if rp_metadata:
name = (
rp_metadata.get("organization_name", "") or
rp_metadata.get("client_name", "") or
rp_metadata.get("client_id", "")
)
return name
4 changes: 1 addition & 3 deletions spid_cie_oidc/provider/views/authz_request_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ def get(self, request, *args, **kwargs):
# stores the authz request in a hidden field in the form
form = self.get_login_form()()
context = {
"client_organization_name": tc.metadata.get(
"client_name", self.payload["client_id"]
),
"client_organization_name": self.get_client_organization_name(tc),
"hidden_form": AuthzHiddenForm(dict(authz_request_object=req)),
"form": form,
"redirect_uri": self.payload["redirect_uri"],
Expand Down
4 changes: 1 addition & 3 deletions spid_cie_oidc/provider/views/consent_page_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def get(self, request, *args, **kwargs):
context = {
"form": self.get_consent_form()(),
"session": session,
"client_organization_name": tc.metadata.get(
"client_name", session.client_id
),
"client_organization_name": self.get_client_organization_name(tc),
"user_claims": sorted(set(i18n_user_claims),),
"redirect_uri": session.authz_request["redirect_uri"],
"state": session.authz_request["state"]
Expand Down

0 comments on commit 5355473

Please sign in to comment.