mirror of
https://gitea.publichub.eu/oscar.krause/fastapi-dls.git
synced 2025-11-03 12:26:11 +00:00
Merge branch 'refs/heads/dev' into db
# Conflicts: # .gitlab-ci.yml # Dockerfile # README.md # app/main.py # app/orm.py # requirements.txt
This commit is contained in:
111
app/main.py
111
app/main.py
@@ -1,50 +1,80 @@
|
||||
import logging
|
||||
from base64 import b64encode as b64enc
|
||||
from calendar import timegm
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from hashlib import sha256
|
||||
from uuid import uuid4
|
||||
from os.path import join, dirname
|
||||
from json import loads as json_loads
|
||||
from os import getenv as env
|
||||
from os.path import join, dirname
|
||||
from uuid import uuid4
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.requests import Request
|
||||
from json import loads as json_loads
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from calendar import timegm
|
||||
from jose import jws, jwt, JWTError
|
||||
from jose import jws, jwk, jwt, JWTError
|
||||
from jose.constants import ALGORITHMS
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
|
||||
|
||||
from orm import init as db_init, migrate, Site, Instance, Origin, Lease
|
||||
|
||||
# Load variables
|
||||
load_dotenv('../version.env')
|
||||
|
||||
# get local timezone
|
||||
# Get current timezone
|
||||
TZ = datetime.now().astimezone().tzinfo
|
||||
|
||||
# fetch version info
|
||||
# Load basic variables
|
||||
VERSION, COMMIT, DEBUG = env('VERSION', 'unknown'), env('COMMIT', 'unknown'), bool(env('DEBUG', False))
|
||||
|
||||
# fastapi setup
|
||||
config = dict(openapi_url='/-/openapi.json', docs_url=None, redoc_url=None)
|
||||
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, **config)
|
||||
|
||||
# database setup
|
||||
# Database connection
|
||||
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
|
||||
db_init(db), migrate(db)
|
||||
|
||||
# DLS setup (static)
|
||||
# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
|
||||
DLS_URL = str(env('DLS_URL', 'localhost'))
|
||||
DLS_PORT = int(env('DLS_PORT', '443'))
|
||||
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
|
||||
|
||||
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) # todo
|
||||
|
||||
# fastapi middleware
|
||||
|
||||
# FastAPI
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
# on startup
|
||||
default_instance = Instance.get_default_instance(db)
|
||||
|
||||
lease_renewal_period = default_instance.lease_renewal_period
|
||||
lease_renewal_delta = default_instance.get_lease_renewal_delta()
|
||||
client_token_expire_delta = default_instance.get_client_token_expire_delta()
|
||||
|
||||
logger.info(f'''
|
||||
Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
|
||||
|
||||
Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
|
||||
If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
|
||||
|
||||
Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
|
||||
''')
|
||||
|
||||
logger.info(f'Debug is {"enabled" if DEBUG else "disabled"}.')
|
||||
|
||||
validate_settings()
|
||||
|
||||
yield
|
||||
|
||||
# on shutdown
|
||||
logger.info(f'Shutting down ...')
|
||||
|
||||
|
||||
config = dict(openapi_url=None, docs_url=None, redoc_url=None) # dict(openapi_url='/-/openapi.json', docs_url='/-/docs', redoc_url='/-/redoc')
|
||||
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, lifespan=lifespan, **config)
|
||||
|
||||
app.debug = DEBUG
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -54,10 +84,20 @@ app.add_middleware(
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
# logging
|
||||
logging.basicConfig()
|
||||
# Logging
|
||||
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
|
||||
logging.basicConfig(format='[{levelname:^7}] [{module:^15}] {message}', style='{')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)
|
||||
logger.setLevel(LOG_LEVEL)
|
||||
logging.getLogger('util').setLevel(LOG_LEVEL)
|
||||
logging.getLogger('NV').setLevel(LOG_LEVEL)
|
||||
|
||||
|
||||
# Helper
|
||||
def __get_token(request: Request) -> dict:
|
||||
authorization_header = request.headers.get('authorization')
|
||||
token = authorization_header.split(' ')[1]
|
||||
return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
|
||||
|
||||
|
||||
def validate_settings():
|
||||
@@ -72,11 +112,7 @@ def validate_settings():
|
||||
session.close()
|
||||
|
||||
|
||||
def __get_token(request: Request, jwt_decode_key: "jose.jwt") -> dict:
|
||||
authorization_header = request.headers.get('authorization')
|
||||
token = authorization_header.split(' ')[1]
|
||||
return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
|
||||
|
||||
# Endpoints
|
||||
|
||||
@app.get('/', summary='Index')
|
||||
async def index():
|
||||
@@ -118,8 +154,7 @@ async def _config():
|
||||
async def _readme():
|
||||
from markdown import markdown
|
||||
from util import load_file
|
||||
|
||||
content = load_file('../README.md').decode('utf-8')
|
||||
content = load_file(join(dirname(__file__), '../README.md')).decode('utf-8')
|
||||
return HTMLr(markdown(text=content, extensions=['tables', 'fenced_code', 'md_in_html', 'nl2br', 'toc']))
|
||||
|
||||
|
||||
@@ -595,26 +630,6 @@ async def leasing_v1_lessor_shutdown(request: Request):
|
||||
return JSONr(response)
|
||||
|
||||
|
||||
@app.on_event('startup')
|
||||
async def app_on_startup():
|
||||
default_instance = Instance.get_default_instance(db)
|
||||
|
||||
lease_renewal_period = default_instance.lease_renewal_period
|
||||
lease_renewal_delta = default_instance.get_lease_renewal_delta()
|
||||
client_token_expire_delta = default_instance.get_client_token_expire_delta()
|
||||
|
||||
logger.info(f'''
|
||||
Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
|
||||
|
||||
Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
|
||||
If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
|
||||
|
||||
Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
|
||||
''')
|
||||
|
||||
validate_settings()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import uvicorn
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text, BLOB, INT, FLOAT
|
||||
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship
|
||||
|
||||
from app.util import parse_key
|
||||
from util import NV
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -148,6 +149,8 @@ class Origin(Base):
|
||||
return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})'
|
||||
|
||||
def serialize(self) -> dict:
|
||||
_ = NV().find(self.guest_driver_version)
|
||||
|
||||
return {
|
||||
'origin_ref': self.origin_ref,
|
||||
# 'service_instance_xid': self.service_instance_xid,
|
||||
@@ -155,6 +158,7 @@ class Origin(Base):
|
||||
'guest_driver_version': self.guest_driver_version,
|
||||
'os_platform': self.os_platform,
|
||||
'os_version': self.os_version,
|
||||
'$driver': _ if _ is not None else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
60
app/util.py
60
app/util.py
@@ -1,10 +1,17 @@
|
||||
def load_file(filename) -> bytes:
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
|
||||
|
||||
def load_file(filename: str) -> bytes:
|
||||
log = logging.getLogger(f'{__name__}')
|
||||
log.debug(f'Loading contents of file "{filename}')
|
||||
with open(filename, 'rb') as file:
|
||||
content = file.read()
|
||||
return content
|
||||
|
||||
|
||||
def load_key(filename) -> "RsaKey":
|
||||
def load_key(filename: str) -> "RsaKey":
|
||||
try:
|
||||
# Crypto | Cryptodome on Debian
|
||||
from Crypto.PublicKey import RSA
|
||||
@@ -13,6 +20,8 @@ def load_key(filename) -> "RsaKey":
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.PublicKey.RSA import RsaKey
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.debug(f'Importing RSA-Key from "{filename}"')
|
||||
return RSA.import_key(extern_key=load_file(filename), passphrase=None)
|
||||
|
||||
|
||||
@@ -36,5 +45,50 @@ def generate_key() -> "RsaKey":
|
||||
except ModuleNotFoundError:
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.PublicKey.RSA import RsaKey
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.debug(f'Generating RSA-Key')
|
||||
return RSA.generate(bits=2048)
|
||||
|
||||
|
||||
class NV:
|
||||
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
|
||||
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"
|
||||
|
||||
def __init__(self):
|
||||
self.log = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
if NV.__DRIVER_MATRIX is None:
|
||||
from json import load as json_load
|
||||
try:
|
||||
file = open(NV.__DRIVER_MATRIX_FILENAME)
|
||||
NV.__DRIVER_MATRIX = json_load(file)
|
||||
file.close()
|
||||
self.log.debug(f'Successfully loaded "{NV.__DRIVER_MATRIX_FILENAME}".')
|
||||
except Exception as e:
|
||||
NV.__DRIVER_MATRIX = {} # init empty dict to not try open file everytime, just when restarting app
|
||||
# self.log.warning(f'Failed to load "{NV.__DRIVER_MATRIX_FILENAME}": {e}')
|
||||
|
||||
@staticmethod
|
||||
def find(version: str) -> dict | None:
|
||||
if NV.__DRIVER_MATRIX is None:
|
||||
return None
|
||||
for idx, (key, branch) in enumerate(NV.__DRIVER_MATRIX.items()):
|
||||
for release in branch.get('$releases'):
|
||||
linux_driver = release.get('Linux Driver')
|
||||
windows_driver = release.get('Windows Driver')
|
||||
if version == linux_driver or version == windows_driver:
|
||||
tmp = branch.copy()
|
||||
tmp.pop('$releases')
|
||||
|
||||
is_latest = release.get('vGPU Software') == branch.get('Latest Release in Branch')
|
||||
|
||||
return {
|
||||
'software_branch': branch.get('vGPU Software Branch'),
|
||||
'branch_version': release.get('vGPU Software'),
|
||||
'driver_branch': branch.get('Driver Branch'),
|
||||
'branch_status': branch.get('vGPU Branch Status'),
|
||||
'release_date': release.get('Release Date'),
|
||||
'eol': branch.get('EOL Date') if is_latest else None,
|
||||
'is_latest': is_latest,
|
||||
}
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user