Explorar el Código

Reinstate mypy (#10)

Casper van der Wel hace 1 año
padre
commit
23bac4516c

+ 7 - 0
.pre-commit-config.yaml

@@ -26,3 +26,10 @@ repos:
       - id: flake8
         # NB The "exclude" setting in setup.cfg is ignored by pre-commit
         exclude: 'migrations*|urls*|settings*'
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.5.1'
+    hooks:
+    - id: mypy
+      exclude: tests
+      additional_dependencies:
+        - 'pydantic==2.*'

+ 2 - 0
CHANGES.md

@@ -10,6 +10,8 @@
 
 - Added S3Gateway.
 
+- Reinstate static type linter (mypy).
+
 
 ## 0.3.4 (2023-08-28)
 ---------------------

+ 0 - 5
clean_python/base/domain/context.py

@@ -46,11 +46,6 @@ class Context:
             "tenant_value", default=None
         )
 
-    def reset(self):
-        self._path_value.reset()
-        self._user_value.reset()
-        self._tenant_value.reset()
-
     @property
     def path(self) -> AnyUrl:
         return self._path_value.get()

+ 6 - 6
clean_python/base/domain/domain_event.py

@@ -1,8 +1,7 @@
 # (c) Nelen & Schuurmans
-
-from abc import ABC
 from typing import Awaitable
 from typing import Callable
+from typing import Type
 from typing import TypeVar
 
 import blinker
@@ -10,17 +9,18 @@ import blinker
 __all__ = ["DomainEvent"]
 
 
-TDomainEvent = TypeVar("TDomainEvent", bound="DomainEvent")
-TEventHandler = Callable[[TDomainEvent], Awaitable[None]]
+T = TypeVar("T", bound="DomainEvent")
 
 
-class DomainEvent(ABC):
+class DomainEvent:
     @classmethod
     def _signal(cls) -> blinker.Signal:
         return blinker.signal(cls.__name__)
 
     @classmethod
-    def register_handler(cls, receiver: TEventHandler) -> TEventHandler:
+    def register_handler(
+        cls: Type[T], receiver: Callable[[T], Awaitable[None]]
+    ) -> Callable[[T], Awaitable[None]]:
         return cls._signal().connect(receiver)
 
     async def send_async(self) -> None:

+ 5 - 5
clean_python/base/domain/exceptions.py

@@ -62,19 +62,19 @@ class BadRequest(Exception):
         self._internal_error = err_or_msg
         super().__init__(err_or_msg)
 
-    def errors(self) -> List[Dict]:
+    def errors(self) -> List[Dict[str, Any]]:
         if isinstance(self._internal_error, ValidationError):
-            return self._internal_error.errors()
+            return [dict() for x in self._internal_error.errors()]
         return [{"error": self}]
 
     def __str__(self) -> str:
         error = self._internal_error
         if isinstance(error, ValidationError):
-            error = error.errors()[0]
-            loc = "'" + ",".join([str(x) for x in error["loc"]]) + "' "
+            details = error.errors()[0]
+            loc = "'" + ",".join([str(x) for x in details["loc"]]) + "' "
             if loc == "'*' ":
                 loc = ""
-            return f"validation error: {loc}{error['msg']}"
+            return f"validation error: {loc}{details['msg']}"
         return f"validation error: {super().__str__()}"
 
 

+ 1 - 1
clean_python/base/infrastructure/tmpdir_provider.py

@@ -10,5 +10,5 @@ class TmpDirProvider:
     def __init__(self, dir: Optional[str] = None):
         self.dir = dir
 
-    def __call__(self) -> TemporaryDirectory:
+    def __call__(self) -> TemporaryDirectory:  # type: ignore
         return TemporaryDirectory(dir=self.dir)

+ 1 - 1
clean_python/dramatiq/async_actor.py

@@ -171,7 +171,7 @@ class AsyncActor(dramatiq.Actor):
     def send_async_with_options(
         self,
         *,
-        args: tuple = (),
+        args: tuple = (),  # type: ignore
         kwargs: Optional[Dict[str, Any]] = None,
         delay: Optional[int] = None,
         **options,

+ 1 - 1
clean_python/fastapi/request_query.py

@@ -23,7 +23,7 @@ class RequestQuery(ValueObject):
     def validate_order_by_enum(cls, v, _):
         # the 'enum' parameter doesn't actually do anthing in validation
         # See: https://github.com/tiangolo/fastapi/issues/2910
-        allowed = cls.model_fields["order_by"].json_schema_extra["enum"]
+        allowed = cls.model_json_schema()["properties"]["order_by"]["enum"]
         if v not in allowed:
             raise ValueError(f"'order_by' must be one of {allowed}")
         return v

+ 3 - 2
clean_python/fastapi/service.py

@@ -2,6 +2,7 @@
 
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Set
@@ -12,6 +13,7 @@ from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from starlette.types import ASGIApp
 
+from clean_python import BadRequest
 from clean_python import Conflict
 from clean_python import ctx
 from clean_python import DoesNotExist
@@ -22,7 +24,6 @@ from clean_python.oauth2 import OAuth2SPAClientSettings
 from clean_python.oauth2 import Token
 from clean_python.oauth2 import TokenVerifierSettings
 
-from .error_responses import BadRequest
 from .error_responses import conflict_handler
 from .error_responses import DefaultErrorResponse
 from .error_responses import not_found_handler
@@ -43,7 +44,7 @@ from .security import set_verifier
 __all__ = ["Service"]
 
 
-def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> None:
+def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str, Any]:
     if auth_client is None:
         return {
             "dependencies": [Depends(JWTBearerTokenSchema()), Depends(set_context)],

+ 2 - 2
clean_python/oauth2/token.py

@@ -27,7 +27,7 @@ class Token(ValueObject):
 
     @property
     def user(self) -> User:
-        return User(id=self.claims["sub"], name=self.claims.get("username"))
+        return User(id=self.claims["sub"], name=self.claims["username"])
 
     @property
     def scope(self) -> Scope:
@@ -36,6 +36,6 @@ class Token(ValueObject):
     @property
     def tenant(self) -> Optional[Tenant]:
         if self.claims.get("tenant"):
-            return Tenant(id=self.claims["tenant"], name=self.claims.get("tenant_name"))
+            return Tenant(id=self.claims["tenant"], name=self.claims["tenant_name"])
         else:
             return None

+ 6 - 5
clean_python/oauth2/token_verifier.py

@@ -1,6 +1,7 @@
 # (c) Nelen & Schuurmans
 
 import logging
+from typing import Any
 from typing import Dict
 from typing import FrozenSet
 from typing import List
@@ -90,22 +91,22 @@ class TokenVerifier(BaseTokenVerifier):
         if authorization is None:
             logger.info("Missing Authorization header")
             raise Unauthorized()
-        token = authorization[7:] if authorization.startswith("Bearer") else None
-        if token is None:
+        jwt_str = authorization[7:] if authorization.startswith("Bearer") else None
+        if jwt_str is None:
             logger.info("Authorization does not start with 'Bearer '")
             raise Unauthorized()
 
         # Step 1: Confirm the structure of the JWT. This check is part of get_kid since
         # jwt.get_unverified_header will raise a JWTError if the structure is wrong.
         try:
-            key = self.get_key(token)  # JSON Web Key
+            key = self.get_key(jwt_str)  # JSON Web Key
         except PyJWTError as e:
             logger.info("Token is invalid: %s", e)
             raise Unauthorized()
         # Step 2: Validate the JWT signature and standard claims
         try:
             claims = jwt.decode(
-                token,
+                jwt_str,
                 key.key,
                 algorithms=self.settings.algorithms,
                 issuer=self.settings.issuer,
@@ -134,7 +135,7 @@ class TokenVerifier(BaseTokenVerifier):
         """Return the JSON Web KEY (JWK) corresponding to kid."""
         return self.jwk_client.get_signing_key_from_jwt(token)
 
-    def verify_token_use(self, claims: Dict) -> None:
+    def verify_token_use(self, claims: Dict[str, Any]) -> None:
         """Check the token_use claim."""
         if claims["token_use"] != "access":
             logger.info("Token has invalid token_use claim: %s", claims["token_use"])

+ 3 - 1
clean_python/testing/attr_dict.py

@@ -1,9 +1,11 @@
 # (c) Nelen & Schuurmans
+from typing import Any
+from typing import Dict
 
 __all__ = ["AttrDict"]
 
 
-class AttrDict(dict):
+class AttrDict(Dict[str, Any]):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.__dict__ = self

+ 18 - 0
mypy.ini

@@ -0,0 +1,18 @@
+[mypy]
+plugins = pydantic.mypy
+ignore_missing_imports = True
+exclude = /(*tests*)
+follow_imports = silent
+warn_redundant_casts = True
+warn_unused_ignores = True
+disallow_any_generics = True
+check_untyped_defs = True
+no_implicit_reexport = True
+# to enforce typing everywhere:
+# disallow_untyped_defs = True
+
+[pydantic-mypy]
+init_forbid_extra = True
+init_typed = True
+warn_required_dynamic_aliases = True
+warn_untyped_fields = True