main.py - refactorings & added simple "dataset" database

This commit is contained in:
Oscar Krause 2022-12-20 14:55:07 +01:00
parent 3c252e2a4c
commit 7ab5e7b264

View File

@ -1,4 +1,4 @@
from base64 import b64encode from base64 import b64encode as b64enc
from hashlib import sha256 from hashlib import sha256
from uuid import uuid4 from uuid import uuid4
from os.path import join, dirname from os.path import join, dirname
@ -12,6 +12,7 @@ from calendar import timegm
from jose import jws, jwk, jwt from jose import jws, jwk, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from starlette.responses import StreamingResponse, JSONResponse from starlette.responses import StreamingResponse, JSONResponse
import dataset
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey from Crypto.PublicKey.RSA import RsaKey
@ -28,7 +29,7 @@ def load_key(filename) -> RsaKey:
# todo: initialize certificate (or should be done by user, and passed through "volumes"?) # todo: initialize certificate (or should be done by user, and passed through "volumes"?)
app = FastAPI() app, db = FastAPI(), dataset.connect('sqlite:///db.sqlite')
LEASE_EXPIRE_DELTA = relativedelta(minutes=15) # days=90 LEASE_EXPIRE_DELTA = relativedelta(minutes=15) # days=90
@ -38,6 +39,9 @@ SITE_KEY_XID = getenv('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')
INSTANCE_KEY_RSA = load_key(join(dirname(__file__), 'cert/instance.private.pem')) INSTANCE_KEY_RSA = load_key(join(dirname(__file__), 'cert/instance.private.pem'))
INSTANCE_KEY_PUB = load_key(join(dirname(__file__), 'cert/instance.public.pem')) INSTANCE_KEY_PUB = load_key(join(dirname(__file__), 'cert/instance.public.pem'))
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
@app.get('/') @app.get('/')
async def index(): async def index():
@ -91,8 +95,7 @@ async def client_token():
"service_instance_public_key_configuration": service_instance_public_key_configuration, "service_instance_public_key_configuration": service_instance_public_key_configuration,
} }
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256) data = jws.sign(payload, key=jwt_encode_key, headers=None, algorithm='RS256')
data = jws.sign(payload, key=key, headers=None, algorithm='RS256')
response = StreamingResponse(iter([data]), media_type="text/plain") response = StreamingResponse(iter([data]), media_type="text/plain")
filename = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}' filename = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}'
@ -101,17 +104,25 @@ async def client_token():
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
# {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["ff:ff:ff:ff:ff:ff"]},"hostname":"my-hostname","ip_address_list":["192.168.178.123","fe80::","fe80::1%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false}
@app.post('/auth/v1/origin') @app.post('/auth/v1/origin')
async def auth_origin(request: Request): async def auth_origin(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8')
j = json.loads(body) candidate_origin_ref = j['candidate_origin_ref']
# {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["ff:ff:ff:ff:ff:ff"]},"hostname":"my-hostname","ip_address_list":["192.168.178.123","fe80::","fe80::1%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false} print(f'> [ origin ]: {candidate_origin_ref}: {j}')
print(f'> [ origin ]: {j}')
data = dict(
candidate_origin_ref=candidate_origin_ref,
hostname=j['environment']['hostname'],
guest_driver_version=j['environment']['guest_driver_version'],
os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version']
)
db['origin'].insert_ignore(data, ['candidate_origin_ref'])
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
"origin_ref": j['candidate_origin_ref'], "origin_ref": candidate_origin_ref,
"environment": j['environment'], "environment": j['environment'],
"svc_port_set_list": None, "svc_port_set_list": None,
"node_url_list": None, "node_url_list": None,
@ -124,13 +135,13 @@ async def auth_origin(request: Request):
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse # venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse
# {"code_challenge":"...","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"}
@app.post('/auth/v1/code') @app.post('/auth/v1/code')
async def auth_code(request: Request): async def auth_code(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8')
j = json.loads(body) origin_ref = j['origin_ref']
# {"code_challenge":"...","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"} print(f'> [ code ]: {origin_ref}: {j}')
print(f'> [ code ]: {j}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
expires = cur_time + relativedelta(days=1) expires = cur_time + relativedelta(days=1)
@ -144,12 +155,7 @@ async def auth_code(request: Request):
'kid': SITE_KEY_XID 'kid': SITE_KEY_XID
} }
headers = None auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm='RS256')
kid = payload.get('kid')
if kid:
headers = {'kid': kid}
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
auth_code = jws.sign(payload, key, headers=headers, algorithm='RS256')
response = { response = {
"auth_code": auth_code, "auth_code": auth_code,
@ -161,19 +167,17 @@ async def auth_code(request: Request):
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse # venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse
# {"auth_code":"...","code_verifier":"..."}
@app.post('/auth/v1/token') @app.post('/auth/v1/token')
async def auth_token(request: Request): async def auth_token(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8') payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key)
j = json.loads(body)
# {"auth_code":"...","code_verifier":"..."}
# payload = self._security.get_valid_payload(req.auth_code) # todo origin_ref = payload['origin_ref']
key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512) print(f'> [ auth ]: {origin_ref}: {j}')
payload = jwt.decode(token=j['auth_code'], key=key)
# validate the code challenge # validate the code challenge
if payload['challenge'] != b64encode(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'): if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'):
raise HTTPException(status_code=403, detail='expected challenge did not match verifier') raise HTTPException(status_code=403, detail='expected challenge did not match verifier')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
@ -190,12 +194,7 @@ async def auth_token(request: Request):
'kid': SITE_KEY_XID, 'kid': SITE_KEY_XID,
} }
headers = None auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm='RS256')
kid = payload.get('kid')
if kid:
headers = {'kid': kid}
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
auth_token = jwt.encode(new_payload, key=key, headers=headers, algorithm='RS256')
response = { response = {
"expires": access_expires_on.isoformat(), "expires": access_expires_on.isoformat(),
@ -206,18 +205,20 @@ async def auth_token(request: Request):
return JSONResponse(response) return JSONResponse(response)
# {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['00112233-4455-6677-8899-aabbccddeeff']}
@app.post('/leasing/v1/lessor') @app.post('/leasing/v1/lessor')
async def leasing_lessor(request: Request): async def leasing_lessor(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8') token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
j = json.loads(body)
# {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['00112233-4455-6677-8899-aabbccddeeff']} code_challenge = token['origin_ref']
print(f'> [ lessor ]: {j}') scope_ref_list = j['scope_ref_list']
print(f'> [ lessor ]: {code_challenge}: {j}')
print(f'> {code_challenge}: create leases for scope_ref_list {scope_ref_list}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
# todo: keep track of leases, to return correct list on '/leasing/v1/lessor/leases'
lease_result_list = [] lease_result_list = []
for scope_ref in j['scope_ref_list']: for scope_ref in scope_ref_list:
lease_result_list.append({ lease_result_list.append({
"ordinal": 0, "ordinal": 0,
"lease": { "lease": {
@ -229,6 +230,8 @@ async def leasing_lessor(request: Request):
"license_type": "CONCURRENT_COUNTED_SINGLE" "license_type": "CONCURRENT_COUNTED_SINGLE"
} }
}) })
data = dict(origin_ref=code_challenge, lease_ref=scope_ref, expires=None, last_update=None)
db['leases'].insert_ignore(data, ['origin_ref', 'lease_ref'])
response = { response = {
"lease_result_list": lease_result_list, "lease_result_list": lease_result_list,
@ -243,13 +246,23 @@ async def leasing_lessor(request: Request):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py # venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.get('/leasing/v1/lessor/leases') @app.get('/leasing/v1/lessor/leases')
async def leasing_lessor_lease(request: Request): async def leasing_lessor_lease(request: Request):
token = jwt.decode(request.headers['authorization'].split(' ')[1], key=key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
active_lease_list = list(map(lambda x: x['lease_ref'], db['leases'].find(origin_ref=code_challenge)))
print(f'> {code_challenge}: found {len(active_lease_list)} active leases')
if len(active_lease_list) == 0:
raise HTTPException(status_code=400, detail="No leases available")
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql # venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
response = { response = {
# GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE # "active_lease_list": [
"active_lease_list": [ # # "BE276D7B-2CDB-11EC-9838-061A22468B59" # (works on Linux) GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE // 'NVIDIA Virtual PC','NVIDIA Virtual PC'
"BE276D7B-2CDB-11EC-9838-061A22468B59" # "BE276EFE-2CDB-11EC-9838-061A22468B59" # GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE // 'NVIDIA RTX Virtual Workstation','NVIDIA RTX Virtual Workstation
], # ],
"active_lease_list": active_lease_list,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
"prompts": None "prompts": None
} }
@ -260,26 +273,44 @@ async def leasing_lessor_lease(request: Request):
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py # venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}') @app.put('/leasing/v1/lease/{lease_ref}')
async def leasing_lease_renew(request: Request, lease_ref: str): async def leasing_lease_renew(request: Request, lease_ref: str):
print(f'> [ renew ]: lease: {lease_ref}') token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
print(f'> {code_challenge}: renew {lease_ref}')
if db['leases'].count(lease_ref=lease_ref) == 0:
raise HTTPException(status_code=400, detail="No leases available")
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
expires = cur_time + LEASE_EXPIRE_DELTA
response = { response = {
"lease_ref": lease_ref, "lease_ref": lease_ref,
"expires": (cur_time + LEASE_EXPIRE_DELTA).isoformat(), "expires": expires.isoformat(),
"recommended_lease_renewal": 0.16, "recommended_lease_renewal": 0.16,
# 0.16 = 10 min, 0.25 = 15 min, 0.33 = 20 min, 0.5 = 30 min (should be lower than "LEASE_EXPIRE_DELTA")
"offline_lease": True, "offline_lease": True,
"prompts": None, "prompts": None,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
} }
data = dict(lease_ref=lease_ref, origin_ref=code_challenge, expires=expires, last_update=cur_time)
db['leases'].update(data, ['lease_ref'])
return JSONResponse(response) return JSONResponse(response)
@app.delete('/leasing/v1/lessor/leases') @app.delete('/leasing/v1/lessor/leases')
async def leasing_lessor_lease_remove(request: Request): async def leasing_lessor_lease_remove(request: Request):
token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
released_lease_list = list(map(lambda x: x['lease_ref'], db['leases'].find(origin_ref=code_challenge)))
deletions = db['leases'].delete(origin_ref=code_challenge)
print(f'> {code_challenge}: removed {deletions} leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
"released_lease_list": None, "released_lease_list": released_lease_list,
"release_failure_list": None, "release_failure_list": None,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
"prompts": None "prompts": None