code refactorings after merge from main

This commit is contained in:
Oscar Krause
2025-04-08 13:52:09 +02:00
parent f62f2a2701
commit 20cdaefa1c
3 changed files with 37 additions and 40 deletions

View File

@@ -1,11 +1,17 @@
import logging
from datetime import datetime, timedelta, timezone, UTC
from os import getenv as env
from os.path import join, dirname, isfile
from dateutil.relativedelta import relativedelta
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text
from jose import jwk
from jose.constants import ALGORITHMS
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text, BLOB, INT, FLOAT
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship
from sqlalchemy.schema import CreateTable
from util import NV
from util import NV, PrivateKey, PublicKey
logging.basicConfig()
logger = logging.getLogger(__name__)
@@ -28,7 +34,6 @@ class Site(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Site.__table__).compile(engine)
@staticmethod
@@ -65,7 +70,6 @@ class Instance(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Instance.__table__).compile(engine)
@staticmethod
@@ -111,21 +115,18 @@ class Instance(Base):
def get_client_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.client_token_expire_delta)
def __get_private_key(self) -> "RsaKey":
return parse_key(self.private_key)
def __get_private_key(self) -> "PrivateKey":
return PrivateKey(self.private_key)
def get_public_key(self) -> "RsaKey":
return parse_key(self.public_key)
def get_public_key(self) -> "PublicKey":
return PublicKey(self.public_key)
def get_jwt_encode_key(self) -> "jose.jkw":
from jose import jwk
from jose.constants import ALGORITHMS
return jwk.construct(self.__get_private_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
return jwk.construct(self.__get_private_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_jwt_decode_key(self) -> "jose.jwt":
from jose import jwk
from jose.constants import ALGORITHMS
return jwk.construct(self.get_public_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
return jwk.construct(self.get_public_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_private_key_str(self, encoding: str = 'utf-8') -> str:
return self.private_key.decode(encoding)
@@ -162,7 +163,6 @@ class Origin(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Origin.__table__).compile(engine)
@staticmethod
@@ -241,7 +241,6 @@ class Lease(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Lease.__table__).compile(engine)
@staticmethod
@@ -336,9 +335,7 @@ class Lease(Base):
def init_default_site(session: Session):
from app.util import generate_key
private_key = generate_key()
private_key = PrivateKey.generate()
public_key = private_key.public_key()
site = Site(
@@ -351,8 +348,8 @@ def init_default_site(session: Session):
instance = Instance(
instance_ref=Instance.DEFAULT_INSTANCE_REF,
site_key=site.site_key,
private_key=private_key.export_key(),
public_key=public_key.export_key(),
private_key=private_key.pem(),
public_key=public_key.pem(),
)
session.add(instance)
session.commit()
@@ -379,10 +376,6 @@ def init(engine: Engine):
def migrate(engine: Engine):
from os import getenv as env
from os.path import join, dirname, isfile
from util import load_key
db = inspect(engine)
# todo: add update guide to use 1.LATEST to 2.0
@@ -408,15 +401,15 @@ def migrate(engine: Engine):
default_instance_private_key_path = str(join(dirname(__file__), 'cert/instance.private.pem'))
instance_private_key = env('INSTANCE_KEY_RSA', None)
if instance_private_key is not None:
instance.private_key = load_key(str(instance_private_key))
instance.private_key = PrivateKey(instance_private_key.encode('utf-8'))
elif isfile(default_instance_private_key_path):
instance.private_key = load_key(default_instance_private_key_path)
instance.private_key = PrivateKey.from_file(default_instance_private_key_path)
default_instance_public_key_path = str(join(dirname(__file__), 'cert/instance.public.pem'))
instance_public_key = env('INSTANCE_KEY_PUB', None)
if instance_public_key is not None:
instance.public_key = load_key(str(instance_public_key))
instance.public_key = PublicKey(instance_public_key.encode('utf-8'))
elif isfile(default_instance_public_key_path):
instance.public_key = load_key(default_instance_public_key_path)
instance.public_key = PublicKey.from_file(default_instance_public_key_path)
# TOKEN_EXPIRE_DELTA
token_expire_delta = env('TOKEN_EXPIRE_DAYS', None)