import logging
from base64 import b64encode as b64enc
from calendar import timegm
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from hashlib import sha256
from json import loads as json_loads
from os import getenv as env
from os.path import join, dirname
from uuid import uuid4

from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.requests import Request
from jose import jws, jwk, jwt, JWTError
from jose.constants import ALGORITHMS
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse

from orm import init as db_init, migrate, Site, Instance, Origin, Lease

# Load variables
load_dotenv('../version.env')

# Get current timezone
TZ = datetime.now().astimezone().tzinfo

# Load basic variables
VERSION, COMMIT = env('VERSION', 'unknown'), env('COMMIT', 'unknown')
# Parse DEBUG from its string value: `bool(env(...))` would be True for any non-empty
# string, so "false" or "0" would have enabled debug mode.
DEBUG = str(env('DEBUG', '')).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

# Database connection
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
db_init(db)
migrate(db)

# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
DLS_URL = str(env('DLS_URL', 'localhost'))
DLS_PORT = int(env('DLS_PORT', '443'))
# Comma-separated list from the environment; otherwise only the DLS's own https origin is allowed.
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']

ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))  # todo


# FastAPI
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Log the effective lease/token timing settings on startup and a notice on shutdown."""
    # on startup
    default_instance = Instance.get_default_instance(db)

    lease_renewal_period = default_instance.lease_renewal_period
    lease_renewal_delta = default_instance.get_lease_renewal_delta()
    client_token_expire_delta = default_instance.get_client_token_expire_delta()

    logger.info(f'''
  Using timezone: {str(TZ)}. Make sure this is correct and match your clients!

  Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
  If the renewal fails, the license is valid for {str(lease_renewal_delta)}.

  Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
''')

    logger.info(f'Debug is {"enabled" if DEBUG else "disabled"}.')

    validate_settings()

    yield

    # on shutdown
    logger.info('Shutting down ...')


config = dict(openapi_url=None, docs_url=None, redoc_url=None)  # dict(openapi_url='/-/openapi.json', docs_url='/-/docs', redoc_url='/-/redoc')
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, lifespan=lifespan, **config)

app.debug = DEBUG
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

# Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
logging.basicConfig(format='[{levelname:^7}] [{module:^15}] {message}', style='{')
logger = logging.getLogger(__name__)
logger.setLevel(LOG_LEVEL)
logging.getLogger('util').setLevel(LOG_LEVEL)
logging.getLogger('NV').setLevel(LOG_LEVEL)


# Helper
def __get_token(request: Request, jwt_decode_key=None) -> dict:
    """Extract and verify the token from the request's ``authorization`` header.

    :param request: incoming request; the header is expected as ``<scheme> <token>``.
    :param jwt_decode_key: key used to verify the RS256 signature. When omitted, the
        default instance's decode key is used (keeps one-argument calls working).
    :returns: the decoded token claims (the audience claim is not verified).
    :raises JWTError: if the header is missing/malformed or the token fails
        verification, so callers can map every authentication failure to a 401.
    """
    if jwt_decode_key is None:
        jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()

    authorization_header = request.headers.get('authorization')
    parts = authorization_header.split(' ') if authorization_header else []
    if len(parts) < 2:
        # Without this guard a missing header surfaced as AttributeError/IndexError (HTTP 500).
        raise JWTError('missing or malformed authorization header')
    token = parts[1]
    return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})


def validate_settings():
    """Log a warning for every instance whose ``lease_expire_delta`` is outside the recommended range."""
    # NOTE(review): bounds presumably in seconds (86_400 = 1 day, 7_776_000 = 90 days) — confirm against orm.Instance.
    lease_expire_delta_min, lease_expire_delta_max = 86_400, 7_776_000

    session = sessionmaker(bind=db)()
    try:
        for instance in session.query(Instance).all():
            lease_expire_delta = instance.lease_expire_delta
            if lease_expire_delta < lease_expire_delta_min or lease_expire_delta > lease_expire_delta_max:
                logging.warning(f'> [  instance  ]: {instance.instance_ref}: "lease_expire_delta" should be between {lease_expire_delta_min} and {lease_expire_delta_max}')
    finally:
        session.close()


def _origin_from_request(origin_ref, environment) -> Origin:
    """Build an ``Origin`` entity from an origin reference and the client's "environment" object.

    A missing/empty ``environment`` yields an Origin whose environment fields are None
    instead of raising AttributeError.
    """
    environment = environment or {}
    return Origin(
        origin_ref=origin_ref,
        hostname=environment.get('hostname'),
        guest_driver_version=environment.get('guest_driver_version'),
        os_platform=environment.get('os_platform'), os_version=environment.get('os_version'),
    )


# Endpoints
# (documented with comments rather than docstrings: FastAPI would otherwise use a
# docstring as the route description)

# Root: redirect to the rendered README.
@app.get('/', summary='Index')
async def index():
    return RedirectResponse('/-/readme')


@app.get('/-/', summary='* Index')
async def _index():
    return RedirectResponse('/-/readme')


# Liveness probe.
@app.get('/-/health', summary='* Health')
async def _health():
    return JSONr({'status': 'up'})


@app.get('/-/config', summary='* Config', description='returns environment variables.')
async def _config():
    default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)

    return JSONr({
        'VERSION': str(VERSION),
        'COMMIT': str(COMMIT),
        'DEBUG': str(DEBUG),
        'DLS_URL': str(DLS_URL),
        'DLS_PORT': str(DLS_PORT),
        'SITE_KEY_XID': str(default_site.site_key),
        'INSTANCE_REF': str(default_instance.instance_ref),
        'ALLOTMENT_REF': [str(ALLOTMENT_REF)],
        'TOKEN_EXPIRE_DELTA': str(default_instance.get_token_expire_delta()),
        'LEASE_EXPIRE_DELTA': str(default_instance.get_lease_expire_delta()),
        'LEASE_RENEWAL_PERIOD': str(default_instance.lease_renewal_period),
        'CORS_ORIGINS': str(CORS_ORIGINS),
        'TZ': str(TZ),
    })


# Render ../README.md (relative to this file) as HTML.
@app.get('/-/readme', summary='* Readme')
async def _readme():
    from markdown import markdown
    from util import load_file
    content = load_file(join(dirname(__file__), '../README.md')).decode('utf-8')
    return HTMLr(markdown(text=content, extensions=['tables', 'fenced_code', 'md_in_html', 'nl2br', 'toc']))


@app.get('/-/manage', summary='* Management UI')
async def _manage(request: Request):
    response = ''' FastAPI-DLS Management '''
    return HTMLr(response)


# List all origins; with ?leases=true each origin also carries its serialized leases.
@app.get('/-/origins', summary='* Origins')
async def _origins(request: Request, leases: bool = False):
    session = sessionmaker(bind=db)()
    try:
        response = []
        for origin in session.query(Origin).all():
            x = origin.serialize()
            if leases:
                x['leases'] = [lease.serialize() for lease in Lease.find_by_origin_ref(db, origin.origin_ref)]
            response.append(x)
    finally:
        # Always release the session, even if a query or serialize() raises.
        session.close()
    return JSONr(response)


@app.delete('/-/origins', summary='* Origins')
async def _origins_delete(request: Request):
    Origin.delete(db)
    return Response(status_code=201)


# List all leases; with ?origin=true each lease also carries its serialized origin (if any).
@app.get('/-/leases', summary='* Leases')
async def _leases(request: Request, origin: bool = False):
    session = sessionmaker(bind=db)()
    try:
        response = []
        for lease in session.query(Lease).all():
            x = lease.serialize()
            if origin:
                lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first()
                if lease_origin is not None:
                    x['origin'] = lease_origin.serialize()
            response.append(x)
    finally:
        # Always release the session, even if a query or serialize() raises.
        session.close()
    return JSONr(response)


@app.delete('/-/leases/expired', summary='* Leases')
async def _lease_delete_expired(request: Request):
    Lease.delete_expired(db)
    return Response(status_code=201)


@app.delete('/-/lease/{lease_ref}', summary='* Lease')
async def _lease_delete(request: Request, lease_ref: str):
    if Lease.delete(db, lease_ref) == 1:
        return Response(status_code=201)
    return JSONr(status_code=404, content={'status': 404, 'detail': 'lease not found'})


# venv/lib/python3.9/site-packages/nls_core_service_instance/service_instance_token_manager.py
@app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance')
async def _client_token():
    cur_time = datetime.utcnow()

    default_instance = Instance.get_default_instance(db)
    public_key = default_instance.get_public_key()
    # todo: implemented request parameter to support different instances
    jwt_encode_key = default_instance.get_jwt_encode_key()

    exp_time = cur_time + default_instance.get_client_token_expire_delta()
    public_numbers = public_key.public_key()

    payload = {
        "jti": str(uuid4()),
        "iss": "NLS Service Instance",
        "aud": "NLS Licensed Client",
        "iat": timegm(cur_time.timetuple()),
        "nbf": timegm(cur_time.timetuple()),
        "exp": timegm(exp_time.timetuple()),
        "update_mode": "ABSOLUTE",
        "scope_ref_list": [ALLOTMENT_REF],
        "fulfillment_class_ref_list": [],
        "service_instance_configuration": {
            "nls_service_instance_ref": default_instance.instance_ref,
            "svc_port_set_list": [
                {
                    "idx": 0,
                    "d_name": "DLS",
                    "svc_port_map": [{"service": "auth", "port": DLS_PORT}, {"service": "lease", "port": DLS_PORT}]
                }
            ],
            "node_url_list": [{"idx": 0, "url": DLS_URL, "url_qr": DLS_URL, "svc_port_set_idx": 0}]
        },
        "service_instance_public_key_configuration": {
            "service_instance_public_key_me": {
                # RSA modulus as hex without the "0x" prefix, and the public exponent
                "mod": hex(public_numbers.n)[2:],
                "exp": int(public_numbers.e),
            },
            "service_instance_public_key_pem": public_key.export_key().decode('utf-8'),
            "key_retention_mode": "LATEST_ONLY"
        },
    }

    content = jws.sign(payload, key=jwt_encode_key, headers=None, algorithm=ALGORITHMS.RS256)

    response = StreamingResponse(iter([content]), media_type="text/plain")
    filename = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}.tok'
    response.headers["Content-Disposition"] = f'attachment; filename={filename}'

    return response


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
@app.post('/auth/v1/origin', description='find or create an origin')
async def auth_v1_origin(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    origin_ref = j.get('candidate_origin_ref')
    logging.info(f'> [  origin  ]: {origin_ref}: {j}')

    data = _origin_from_request(origin_ref, j.get('environment'))
    Origin.create_or_update(db, data)

    response = {
        "origin_ref": origin_ref,
        "environment": j.get('environment'),
        "svc_port_set_list": None,
        "node_url_list": None,
        "node_query_order": None,
        "prompts": None,
        "sync_timestamp": cur_time.isoformat()
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
@app.post('/auth/v1/origin/update', description='update an origin evidence')
async def auth_v1_origin_update(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    origin_ref = j.get('origin_ref')
    logging.info(f'> [  update  ]: {origin_ref}: {j}')

    data = _origin_from_request(origin_ref, j.get('environment'))
    Origin.create_or_update(db, data)

    response = {
        "environment": j.get('environment'),
        "prompts": None,
        "sync_timestamp": cur_time.isoformat()
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse
@app.post('/auth/v1/code', description='get an authorization code')
async def auth_v1_code(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    origin_ref = j.get('origin_ref')
    logging.info(f'> [   code   ]: {origin_ref}: {j}')

    # the signed auth code is valid for 15 minutes
    delta = relativedelta(minutes=15)
    expires = cur_time + delta

    default_site = Site.get_default_site(db)
    jwt_encode_key = Instance.get_default_instance(db).get_jwt_encode_key()

    payload = {
        'iat': timegm(cur_time.timetuple()),
        'exp': timegm(expires.timetuple()),
        'challenge': j.get('code_challenge'),
        'origin_ref': j.get('origin_ref'),
        'key_ref': default_site.site_key,
        'kid': default_site.site_key,
    }

    auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)

    response = {
        "auth_code": auth_code,
        "sync_timestamp": cur_time.isoformat(),
        "prompts": None
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse
@app.post('/auth/v1/token', description='exchange auth code and verifier for token')
async def auth_v1_token(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
    jwt_encode_key, jwt_decode_key = default_instance.get_jwt_encode_key(), default_instance.get_jwt_decode_key()

    try:
        payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key, algorithms=[ALGORITHMS.RS256])
    except JWTError as e:
        return JSONr(status_code=400, content={'status': 400, 'title': 'invalid token', 'detail': str(e)})

    origin_ref = payload.get('origin_ref')
    logging.info(f'> [   auth   ]: {origin_ref}: {j}')

    # validate the code challenge: unpadded base64 of sha256(code_verifier) must equal the
    # challenge embedded in the auth code. A missing/non-string verifier can never match
    # (previously it crashed with AttributeError -> HTTP 500).
    code_verifier = j.get('code_verifier')
    if not isinstance(code_verifier, str):
        return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})
    challenge = b64enc(sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
    if payload.get('challenge') != challenge:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})

    access_expires_on = cur_time + default_instance.get_token_expire_delta()

    new_payload = {
        'iat': timegm(cur_time.timetuple()),
        'nbf': timegm(cur_time.timetuple()),
        'iss': 'https://cls.nvidia.org',
        'aud': 'https://cls.nvidia.org',
        'exp': timegm(access_expires_on.timetuple()),
        'origin_ref': origin_ref,
        'key_ref': default_site.site_key,
        'kid': default_site.site_key,
    }

    auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)

    response = {
        "expires": access_expires_on.isoformat(),
        "auth_token": auth_token,
        "sync_timestamp": cur_time.isoformat(),
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
async def leasing_v1_lessor(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    default_instance = Instance.get_default_instance(db)
    jwt_decode_key = default_instance.get_jwt_decode_key()

    try:
        token = __get_token(request, jwt_decode_key)
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')
    scope_ref_list = j.get('scope_ref_list')
    logging.info(f'> [  create  ]: {origin_ref}: create leases for scope_ref_list {scope_ref_list}')

    lease_result_list = []
    # A missing scope_ref_list yields no leases instead of a TypeError (HTTP 500).
    for scope_ref in scope_ref_list or []:
        # if scope_ref not in [ALLOTMENT_REF]:
        #     return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')

        lease_ref = str(uuid4())
        expires = cur_time + default_instance.get_lease_expire_delta()
        lease_result_list.append({
            "ordinal": 0,
            # https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
            "lease": {
                "ref": lease_ref,
                "created": cur_time.isoformat(),
                "expires": expires.isoformat(),
                "recommended_lease_renewal": default_instance.lease_renewal_period,
                "offline_lease": "true",
                "license_type": "CONCURRENT_COUNTED_SINGLE"
            }
        })

        data = Lease(instance_ref=default_instance.instance_ref, origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
        Lease.create_or_update(db, data)

    response = {
        "lease_result_list": lease_result_list,
        "result_code": "SUCCESS",
        "sync_timestamp": cur_time.isoformat(),
        "prompts": None
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
@app.get('/leasing/v1/lessor/leases', description='get active leases for current origin')
async def leasing_v1_lessor_lease(request: Request):
    cur_time = datetime.utcnow()

    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()

    try:
        token = __get_token(request, jwt_decode_key)
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')

    active_lease_list = [lease.lease_ref for lease in Lease.find_by_origin_ref(db, origin_ref)]
    logging.info(f'> [  leases  ]: {origin_ref}: found {len(active_lease_list)} active leases')

    response = {
        "active_lease_list": active_lease_list,
        "sync_timestamp": cur_time.isoformat(),
        "prompts": None
    }

    return JSONr(response)
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
# Renew one lease owned by the token's origin; 404 if that origin has no such lease.
@app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease')
async def leasing_v1_lease_renew(request: Request, lease_ref: str):
    cur_time = datetime.utcnow()

    default_instance = Instance.get_default_instance(db)
    jwt_decode_key = default_instance.get_jwt_decode_key()

    try:
        token = __get_token(request, jwt_decode_key)
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')
    logging.info(f'> [  renew   ]: {origin_ref}: renew {lease_ref}')

    entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref)
    if entity is None:
        return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})

    expires = cur_time + default_instance.get_lease_expire_delta()
    response = {
        "lease_ref": lease_ref,
        "expires": expires.isoformat(),
        "recommended_lease_renewal": default_instance.lease_renewal_period,
        "offline_lease": True,
        "prompts": None,
        "sync_timestamp": cur_time.isoformat(),
    }

    Lease.renew(db, entity, expires, cur_time)

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
# Release one lease: 404 if unknown, 403 if it belongs to a different origin.
@app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease')
async def leasing_v1_lease_delete(request: Request, lease_ref: str):
    cur_time = datetime.utcnow()

    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()

    try:
        token = __get_token(request, jwt_decode_key)
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')
    logging.info(f'> [  return  ]: {origin_ref}: return {lease_ref}')

    entity = Lease.find_by_lease_ref(db, lease_ref)
    # The None check must come first: the ownership check dereferences entity.origin_ref,
    # which previously raised AttributeError (HTTP 500) for unknown leases.
    if entity is None:
        return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})
    if entity.origin_ref != origin_ref:
        return JSONr(status_code=403, content={'status': 403, 'detail': 'access or operation forbidden'})

    if Lease.delete(db, lease_ref) == 0:
        return JSONr(status_code=404, content={'status': 404, 'detail': 'lease not found'})

    response = {
        "lease_ref": lease_ref,
        "prompts": None,
        "sync_timestamp": cur_time.isoformat(),
    }

    return JSONr(response)


# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
# Release every lease of the token's origin (authenticated via the authorization header).
@app.delete('/leasing/v1/lessor/leases', description='release all leases')
async def leasing_v1_lessor_lease_remove(request: Request):
    cur_time = datetime.utcnow()

    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()

    try:
        token = __get_token(request, jwt_decode_key)
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')

    # Collect the refs before cleanup so the response can list what was released.
    released_lease_list = [lease.lease_ref for lease in Lease.find_by_origin_ref(db, origin_ref)]
    deletions = Lease.cleanup(db, origin_ref)
    logging.info(f'> [  remove  ]: {origin_ref}: removed {deletions} leases')

    response = {
        "released_lease_list": released_lease_list,
        "release_failure_list": None,
        "sync_timestamp": cur_time.isoformat(),
        "prompts": None
    }

    return JSONr(response)


# Release every lease of the origin; unlike the other leasing endpoints, the token is
# taken from the JSON body ("token") rather than from the authorization header.
@app.post('/leasing/v1/lessor/shutdown', description='shutdown all leases')
async def leasing_v1_lessor_shutdown(request: Request):
    j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()

    jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()

    token = j.get('token')
    # A missing or invalid token previously escaped as an unhandled error (HTTP 500);
    # answer with the same 401 as the other leasing endpoints instead.
    if not isinstance(token, str):
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
    try:
        token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
    except JWTError:
        return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})

    origin_ref = token.get('origin_ref')

    released_lease_list = [lease.lease_ref for lease in Lease.find_by_origin_ref(db, origin_ref)]
    deletions = Lease.cleanup(db, origin_ref)
    logging.info(f'> [ shutdown ]: {origin_ref}: removed {deletions} leases')

    response = {
        "released_lease_list": released_lease_list,
        "release_failure_list": None,
        "sync_timestamp": cur_time.isoformat(),
        "prompts": None
    }

    return JSONr(response)


if __name__ == '__main__':
    import uvicorn

    ###
    #
    # Running `python app/main.py` assumes that the user created a keypair, e.g. with openssl.
    #
    # openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout app/cert/webserver.key -out app/cert/webserver.crt
    #
    ###

    logging.info('> Starting dev-server ...')

    ssl_keyfile = join(dirname(__file__), 'cert/webserver.key')
    ssl_certfile = join(dirname(__file__), 'cert/webserver.crt')

    uvicorn.run('main:app', host='0.0.0.0', port=443, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, reload=True)