"""SQLAlchemy ORM models (Site, Instance, Origin, Lease) plus schema init/migration helpers."""
import logging
from datetime import datetime, timedelta
from typing import List, Optional

from dateutil.relativedelta import relativedelta
from sqlalchemy import (
    BLOB,
    CHAR,
    DATETIME,
    FLOAT,
    INT,
    VARCHAR,
    Column,
    ForeignKey,
    and_,
    inspect,
    text,
    update,
)
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship

from util import NV

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Base = declarative_base()


def _open_session(engine: Engine) -> Session:
    """Create a new session bound to *engine*.

    Intended to be used as ``with _open_session(engine) as session:`` so the
    session is closed even when the body raises.
    """
    return sessionmaker(bind=engine)()


class Site(Base):
    """A site; the parent entity that instances reference via ``site_key``."""

    __tablename__ = "site"

    INITIAL_SITE_KEY_XID = '00000000-0000-0000-0000-000000000000'
    INITIAL_SITE_NAME = 'default'

    site_key = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4, SITE_KEY_XID
    name = Column(VARCHAR(length=256), nullable=False)

    def __str__(self):
        return f'SITE_KEY_XID: {self.site_key}'

    @staticmethod
    def create_statement(engine: Engine):
        """Return the compiled ``CREATE TABLE`` statement for this table on *engine*'s dialect."""
        from sqlalchemy.schema import CreateTable
        return CreateTable(Site.__table__).compile(engine)

    @staticmethod
    def get_default_site(engine: Engine) -> Optional["Site"]:
        """Return the site whose key is ``INITIAL_SITE_KEY_XID``, or ``None`` if it does not exist."""
        with _open_session(engine) as session:
            return session.query(Site).filter(Site.site_key == Site.INITIAL_SITE_KEY_XID).first()


class Instance(Base):
    """A service instance belonging to a site; holds its key pair and expiry settings."""

    __tablename__ = "instance"

    DEFAULT_INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
    DEFAULT_TOKEN_EXPIRE_DELTA = 86_400  # 1 day
    DEFAULT_LEASE_EXPIRE_DELTA = 7_776_000  # 90 days
    DEFAULT_LEASE_RENEWAL_PERIOD = 0.15
    DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA = 378_432_000  # 12 years
    # 1 day = 86400 (min. in production setup, max 90 days), 1 hour = 3600

    instance_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4, INSTANCE_REF
    site_key = Column(CHAR(length=36), ForeignKey(Site.site_key, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
    private_key = Column(BLOB(length=2048), nullable=False)
    public_key = Column(BLOB(length=512), nullable=False)
    token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_TOKEN_EXPIRE_DELTA, comment='in seconds')
    lease_expire_delta = Column(INT(), nullable=False, default=DEFAULT_LEASE_EXPIRE_DELTA, comment='in seconds')
    lease_renewal_period = Column(FLOAT(precision=2), nullable=False, default=DEFAULT_LEASE_RENEWAL_PERIOD)
    client_token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA, comment='in seconds')

    __origin = relationship(Site, foreign_keys=[site_key])

    def __str__(self):
        return f'INSTANCE_REF: {self.instance_ref} (SITE_KEY_XID: {self.site_key})'

    @staticmethod
    def create_statement(engine: Engine):
        """Return the compiled ``CREATE TABLE`` statement for this table on *engine*'s dialect."""
        from sqlalchemy.schema import CreateTable
        return CreateTable(Instance.__table__).compile(engine)

    @staticmethod
    def create_or_update(engine: Engine, instance: "Instance"):
        """Insert *instance*, or update the stored row with the same ``instance_ref``."""
        with _open_session(engine) as session:
            entity = session.query(Instance).filter(Instance.instance_ref == instance.instance_ref).first()
            if entity is None:
                session.add(instance)
            else:
                values = dict(
                    site_key=instance.site_key,
                    private_key=instance.private_key,
                    public_key=instance.public_key,
                    token_expire_delta=instance.token_expire_delta,
                    lease_expire_delta=instance.lease_expire_delta,
                    lease_renewal_period=instance.lease_renewal_period,
                    client_token_expire_delta=instance.client_token_expire_delta,
                )
                session.execute(update(Instance).where(Instance.instance_ref == instance.instance_ref).values(**values))
            session.commit()

    # todo: validate on startup that "lease_expire_delta" is between 1 day and 90 days

    @staticmethod
    def get_default_instance(engine: Engine) -> Optional["Instance"]:
        """Return the first instance of the default site, or ``None`` if site or instance is missing."""
        site = Site.get_default_site(engine)
        if site is None:
            # Without a default site there is no key to filter on.
            return None
        with _open_session(engine) as session:
            return session.query(Instance).filter(Instance.site_key == site.site_key).first()

    def get_token_expire_delta(self) -> relativedelta:
        """Return ``token_expire_delta`` (seconds) as a ``relativedelta``."""
        return relativedelta(seconds=self.token_expire_delta)

    def get_lease_expire_delta(self) -> relativedelta:
        """Return ``lease_expire_delta`` (seconds) as a ``relativedelta``."""
        return relativedelta(seconds=self.lease_expire_delta)

    def get_lease_renewal_delta(self) -> timedelta:
        """Return ``lease_expire_delta`` (seconds) as a ``timedelta``.

        This is the base duration that ``Lease.calculate_renewal`` scales by
        ``lease_renewal_period``.
        """
        return timedelta(seconds=self.lease_expire_delta)

    def get_client_token_expire_delta(self) -> relativedelta:
        """Return ``client_token_expire_delta`` (seconds) as a ``relativedelta``."""
        return relativedelta(seconds=self.client_token_expire_delta)

    def __get_private_key(self) -> "RsaKey":
        """Parse the stored private key bytes into a key object."""
        # NOTE(review): `parse_key` was referenced without any import; it is assumed
        # to live in the project's `util` module next to `load_key` - confirm.
        from util import parse_key
        return parse_key(self.private_key)

    def get_public_key(self) -> "RsaKey":
        """Parse the stored public key bytes into a key object."""
        # NOTE(review): see __get_private_key regarding the origin of `parse_key`.
        from util import parse_key
        return parse_key(self.public_key)

    def get_jwt_encode_key(self) -> "jose.jwk.Key":
        """Build the RS256 JWK used for signing, from the private key."""
        from jose import jwk
        from jose.constants import ALGORITHMS
        return jwk.construct(self.__get_private_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)

    def get_jwt_decode_key(self) -> "jose.jwk.Key":
        """Build the RS256 JWK used for verification, from the public key."""
        from jose import jwk
        from jose.constants import ALGORITHMS
        return jwk.construct(self.get_public_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)

    def get_private_key_str(self, encoding: str = 'utf-8') -> str:
        """Return the stored private key bytes decoded with *encoding*."""
        return self.private_key.decode(encoding)

    def get_public_key_str(self, encoding: str = 'utf-8') -> str:
        """Return the stored public key bytes decoded with *encoding*."""
        # Previously decoded `private_key` by mistake, exposing the private key.
        return self.public_key.decode(encoding)


class Origin(Base):
    """A client (origin) that can hold leases, with the host details it reported."""

    __tablename__ = "origin"

    origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True)  # uuid4

    # service_instance_xid = Column(CHAR(length=36), nullable=False, index=True)  # uuid4 # not necessary, we only support one service_instance_xid ('INSTANCE_REF')
    hostname = Column(VARCHAR(length=256), nullable=True)
    guest_driver_version = Column(VARCHAR(length=10), nullable=True)
    os_platform = Column(VARCHAR(length=256), nullable=True)
    os_version = Column(VARCHAR(length=256), nullable=True)

    def __repr__(self):
        return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})'

    def serialize(self) -> dict:
        """Return the origin as a plain dict; ``$driver`` is the ``NV().find`` result for the driver version."""
        driver = NV().find(self.guest_driver_version)

        return {
            'origin_ref': self.origin_ref,
            # 'service_instance_xid': self.service_instance_xid,
            'hostname': self.hostname,
            'guest_driver_version': self.guest_driver_version,
            'os_platform': self.os_platform,
            'os_version': self.os_version,
            '$driver': driver,
        }

    @staticmethod
    def create_statement(engine: Engine):
        """Return the compiled ``CREATE TABLE`` statement for this table on *engine*'s dialect."""
        from sqlalchemy.schema import CreateTable
        return CreateTable(Origin.__table__).compile(engine)

    @staticmethod
    def create_or_update(engine: Engine, origin: "Origin"):
        """Insert *origin*, or update the stored row with the same ``origin_ref``."""
        with _open_session(engine) as session:
            entity = session.query(Origin).filter(Origin.origin_ref == origin.origin_ref).first()
            if entity is None:
                session.add(origin)
            else:
                values = dict(
                    hostname=origin.hostname,
                    guest_driver_version=origin.guest_driver_version,
                    os_platform=origin.os_platform,
                    os_version=origin.os_version,
                )
                session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**values))
            session.commit()

    @staticmethod
    def delete(engine: Engine, origin_refs: Optional[List[str]] = None) -> int:
        """Delete origins and return the number of deleted rows.

        :param origin_refs: refs to delete; ``None`` deletes *all* origins,
            an empty list deletes nothing.
        """
        with _open_session(engine) as session:
            query = session.query(Origin)
            if origin_refs is not None:
                # `Origin.origin_ref in origin_refs` is evaluated by Python to a plain
                # bool and never reaches SQL; `.in_()` builds the actual IN clause.
                query = query.filter(Origin.origin_ref.in_(origin_refs))
            # The session is discarded right after, so no in-session sync is needed.
            deletions = query.delete(synchronize_session=False)
            session.commit()
        return deletions


class Lease(Base):
    """A lease granted to an origin by an instance."""

    __tablename__ = "lease"

    instance_ref = Column(CHAR(length=36), ForeignKey(Instance.instance_ref, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
    lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True)  # uuid4
    origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref, ondelete='CASCADE'), nullable=False, index=True)  # uuid4
    # scope_ref = Column(CHAR(length=36), nullable=False, index=True)  # uuid4 # not necessary, we only support one scope_ref ('ALLOTMENT_REF')
    lease_created = Column(DATETIME(), nullable=False)
    lease_expires = Column(DATETIME(), nullable=False)
    lease_updated = Column(DATETIME(), nullable=False)

    __instance = relationship(Instance, foreign_keys=[instance_ref])
    __origin = relationship(Origin, foreign_keys=[origin_ref])

    def __repr__(self):
        return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})'

    def serialize(self) -> dict:
        """Return the lease as a dict with ISO-formatted timestamps.

        ``lease_renewal`` is ``lease_updated`` plus the renewal interval derived
        from the owning instance's settings.

        NOTE(review): this reads the related instance through the relationship;
        leases returned by the ``find_*`` helpers come from an already closed
        session, so the relationship presumably has to be loaded beforehand
        (or the lease re-attached) - confirm against callers.
        """
        instance = self.__instance
        # `get_lease_renewal_delta` must be *called*; passing the bound method
        # made `calculate_renewal` fail on `.total_seconds()`.
        renewal = Lease.calculate_renewal(instance.lease_renewal_period, instance.get_lease_renewal_delta())
        lease_renewal = self.lease_updated + relativedelta(seconds=int(renewal.total_seconds()))

        return {
            'lease_ref': self.lease_ref,
            'origin_ref': self.origin_ref,
            # 'scope_ref': self.scope_ref,
            'lease_created': self.lease_created.isoformat(),
            'lease_expires': self.lease_expires.isoformat(),
            'lease_updated': self.lease_updated.isoformat(),
            'lease_renewal': lease_renewal.isoformat(),
        }

    @staticmethod
    def create_statement(engine: Engine):
        """Return the compiled ``CREATE TABLE`` statement for this table on *engine*'s dialect."""
        from sqlalchemy.schema import CreateTable
        return CreateTable(Lease.__table__).compile(engine)

    @staticmethod
    def create_or_update(engine: Engine, lease: "Lease"):
        """Insert *lease*, or update origin/expiry/updated of the row with the same ``lease_ref``."""
        with _open_session(engine) as session:
            entity = session.query(Lease).filter(Lease.lease_ref == lease.lease_ref).first()
            if entity is None:
                if lease.lease_updated is None:
                    # A new lease without an explicit update time counts as updated at creation.
                    lease.lease_updated = lease.lease_created
                session.add(lease)
            else:
                values = dict(origin_ref=lease.origin_ref, lease_expires=lease.lease_expires, lease_updated=lease.lease_updated)
                session.execute(update(Lease).where(Lease.lease_ref == lease.lease_ref).values(**values))
            session.commit()

    @staticmethod
    def find_by_origin_ref(engine: Engine, origin_ref: str) -> List["Lease"]:
        """Return all leases of the given origin."""
        with _open_session(engine) as session:
            return session.query(Lease).filter(Lease.origin_ref == origin_ref).all()

    @staticmethod
    def find_by_lease_ref(engine: Engine, lease_ref: str) -> Optional["Lease"]:
        """Return the lease with the given ref, or ``None``."""
        with _open_session(engine) as session:
            return session.query(Lease).filter(Lease.lease_ref == lease_ref).first()

    @staticmethod
    def find_by_origin_ref_and_lease_ref(engine: Engine, origin_ref: str, lease_ref: str) -> Optional["Lease"]:
        """Return the lease matching both origin and lease ref, or ``None``."""
        with _open_session(engine) as session:
            return session.query(Lease).filter(and_(Lease.origin_ref == origin_ref, Lease.lease_ref == lease_ref)).first()

    @staticmethod
    def renew(engine: Engine, lease: "Lease", lease_expires: datetime, lease_updated: datetime):
        """Store new ``lease_expires`` / ``lease_updated`` values for *lease*."""
        values = dict(lease_expires=lease_expires, lease_updated=lease_updated)
        with _open_session(engine) as session:
            session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values))
            session.commit()

    @staticmethod
    def cleanup(engine: Engine, origin_ref: str) -> int:
        """Delete all leases of the given origin; return the number of deleted rows."""
        with _open_session(engine) as session:
            deletions = session.query(Lease).filter(Lease.origin_ref == origin_ref).delete()
            session.commit()
        return deletions

    @staticmethod
    def delete(engine: Engine, lease_ref: str) -> int:
        """Delete the lease with the given ref; return the number of deleted rows."""
        with _open_session(engine) as session:
            deletions = session.query(Lease).filter(Lease.lease_ref == lease_ref).delete()
            session.commit()
        return deletions

    @staticmethod
    def delete_expired(engine: Engine) -> int:
        """Delete leases whose ``lease_expires`` is at or before the current UTC time."""
        with _open_session(engine) as session:
            deletions = session.query(Lease).filter(Lease.lease_expires <= datetime.utcnow()).delete()
            session.commit()
        return deletions

    @staticmethod
    def calculate_renewal(renewal_period: float, delta: timedelta) -> timedelta:
        """Return the fraction *renewal_period* of *delta* as a ``timedelta``.

        Examples::

            calculate_renewal(0.2, timedelta(days=1))    # 4:48:00
            calculate_renewal(0.15, timedelta(days=90))  # 13 days, 12:00:00
        """
        return timedelta(seconds=delta.total_seconds() * renewal_period)


def init_default_site(session: Session):
    """Create the default site and a default instance with a freshly generated key pair.

    Both rows are committed through the given *session*.
    """
    # `util` is imported the same way as at module level and in `migrate`.
    from util import generate_key

    private_key = generate_key()
    public_key = private_key.public_key()

    site = Site(
        site_key=Site.INITIAL_SITE_KEY_XID,
        name=Site.INITIAL_SITE_NAME,
    )
    session.add(site)
    session.commit()

    instance = Instance(
        instance_ref=Instance.DEFAULT_INSTANCE_REF,
        site_key=site.site_key,
        private_key=private_key.export_key(),
        public_key=public_key.export_key(),
    )
    session.add(instance)
    session.commit()


def init(engine: Engine):
    """Create any missing tables and, if no site exists yet, the default site and instance."""
    tables = [Site, Instance, Origin, Lease]
    db = inspect(engine)
    with _open_session(engine) as session:
        for table in tables:
            # Use a scoped connection; the previous bare `engine.connect()` leaked one per table.
            with engine.connect() as connection:
                exists = db.dialect.has_table(connection, table.__tablename__)
            logger.info(f'> Table "{table.__tablename__:<16}" exists: {exists}')
            if not exists:
                session.execute(text(str(table.create_statement(engine))))
                session.commit()

        # create default site
        if session.query(Site).count() == 0:
            init_default_site(session)


def migrate(engine: Engine):
    """Apply 1.x environment-variable settings to the default site/instance objects.

    Raises ``ValueError`` when a numeric environment variable is not a valid number.
    """
    from os import getenv as env
    from os.path import join, dirname, isfile
    from util import load_key

    def env_int(name: str) -> Optional[int]:
        """Return the env var as a non-zero int, or ``None`` when unset, empty or zero."""
        raw = env(name)
        if raw is None or not raw.strip():
            return None
        return int(raw) or None

    # todo: add update guide to use 1.LATEST to 2.0
    def upgrade_1_x_to_2_0():
        site = Site.get_default_site(engine)
        logger.info(site)
        instance = Instance.get_default_instance(engine)
        logger.info(instance)

        if site is None or instance is None:
            logger.warning('No default site/instance found, skipping migration from 1.x')
            return

        # NOTE: in every branch below the walrus target is parenthesized. Without
        # parentheses `x := env(...) is not None` binds `x` to the comparison result (a bool).

        # SITE_KEY_XID
        if (site_key := env('SITE_KEY_XID')) is not None:
            site.site_key = str(site_key)

        # INSTANCE_REF
        if (instance_ref := env('INSTANCE_REF')) is not None:
            instance.instance_ref = str(instance_ref)

        # ALLOTMENT_REF
        # todo: env('ALLOTMENT_REF') is not handled yet

        # INSTANCE_KEY_RSA, INSTANCE_KEY_PUB
        # NOTE(review): `init_default_site` stores exported key *bytes* in these columns;
        # verify that `load_key` returns a compatible value.
        default_instance_private_key_path = str(join(dirname(__file__), 'cert/instance.private.pem'))
        if (instance_private_key := env('INSTANCE_KEY_RSA')) is not None:
            instance.private_key = load_key(str(instance_private_key))
        elif isfile(default_instance_private_key_path):
            instance.private_key = load_key(default_instance_private_key_path)

        default_instance_public_key_path = str(join(dirname(__file__), 'cert/instance.public.pem'))
        if (instance_public_key := env('INSTANCE_KEY_PUB')) is not None:
            instance.public_key = load_key(str(instance_public_key))
        elif isfile(default_instance_public_key_path):
            instance.public_key = load_key(default_instance_public_key_path)

        # TOKEN_EXPIRE_DELTA (hours are applied last and win over days when both are set)
        if (token_expire_days := env_int('TOKEN_EXPIRE_DAYS')) is not None:
            instance.token_expire_delta = token_expire_days * 86_400
        if (token_expire_hours := env_int('TOKEN_EXPIRE_HOURS')) is not None:
            instance.token_expire_delta = token_expire_hours * 3_600

        # LEASE_EXPIRE_DELTA, LEASE_RENEWAL_DELTA (hours win over days when both are set)
        if (lease_expire_days := env_int('LEASE_EXPIRE_DAYS')) is not None:
            instance.lease_expire_delta = lease_expire_days * 86_400
        if (lease_expire_hours := env_int('LEASE_EXPIRE_HOURS')) is not None:
            instance.lease_expire_delta = lease_expire_hours * 3_600

        # LEASE_RENEWAL_PERIOD
        if (lease_renewal_period := env('LEASE_RENEWAL_PERIOD')) is not None:
            instance.lease_renewal_period = float(lease_renewal_period)

        # todo: update site, instance

    upgrade_1_x_to_2_0()