diff --git a/dbHandler.py b/dbHandler.py index f2f208e..faab479 100644 --- a/dbHandler.py +++ b/dbHandler.py @@ -1,3 +1,4 @@ +import sys import typing import psycopg2 from psycopg2 import sql @@ -29,6 +30,15 @@ def connect(databaseOption: str, hostOption: str, userOption: str, passwordOptio errorPrint("Error occurred connecting to database! Exiting...") sys.exit(1) +def disconnect(dbConnection: psycopg2.extensions.connection) -> None: + debugPrint("Attempting to disconnect database...") + try: + dbConnection.close() + debugPrint("Successfully disconnected database!") + except: + debugPrint("Failed to disconnect database! Exiting...") + sys.exit(1) + def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str): dbCursor = dbConnection.cursor() @@ -63,12 +73,16 @@ def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Compos dbCursor.close() return dbResults def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: - debugPrint("Exec query executing...") - dbCursor = dbConnection.cursor() - dbCursor.execute(query) - dbResults = dbCursor.fetchall() - dbCursor.close() - return dbResults + try: + debugPrint("Exec query executing...") + dbCursor = dbConnection.cursor() + dbCursor.execute(query) + dbResults = dbCursor.fetchall() + dbCursor.close() + return dbResults + except Exception as error: + errorPrint("Exec query failed! " + repr(error)) + return None # Callable helper functions def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list): diff --git a/main.py b/main.py index 993f8d3..443a027 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,8 @@ import atexit import signal from typing import Union +from contextlib import asynccontextmanager + from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -11,15 +13,14 @@ from pydantic import BaseModel import dbHandler import userHandler import securityHandler +import tokenHandler -dbConnection = None - +dbConnection = dbHandler.connect("blorgdb", + "172.20.0.10", + "dev", + "dev", + "5432") def apiInit(): - dbConnection = dbHandler.connect("blorgdb", - "172.20.0.10", - "dev", - "dev", - "5432") dbHandler.initTable(dbConnection, "Users", """ ID SERIAL PRIMARY KEY, Username VARCHAR(255), @@ -41,7 +42,7 @@ def apiInit(): """) dbHandler.initTable(dbConnection, "AuthTokens", """ ID SERIAL PRIMARY KEY, - Token VARCHAR(255), + Token VARCHAR(2048), OwnerID INTEGER, DateCreated TIMESTAMP, DateExpiry TIMESTAMP, @@ -58,10 +59,10 @@ def apiInit(): ID SERIAL PRIMARY KEY, Name VARCHAR(255) """) - userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser") + # userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser") def apiCleanup(): - dbConnection.close() + dbHandler.disconnect(dbConnection) @asynccontextmanager async def apiLifespan(app: FastAPI): @@ -107,18 +108,19 @@ class loginBody(BaseModel): @app.post("/api/login") def postlogin(body: loginBody, request: Request): try: - if userHandler.checkUserExistence(dbConnection, loginBody.username): - userID = userHandler.getIDByUsername(dbConnection, loginBody.username) - if securityHandler.handlePassword(dbConnection, loginBody.password, userID): - return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"} + if userHandler.checkUserExistence(dbConnection, body.username): + userID = userHandler.getIDByUsername(dbConnection, body.username) + if securityHandler.handlePassword(dbConnection, body.password, userID): + return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, body.rememberMe, request.client.host), "message": "User login success!"} else: return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."} else: return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."} - except: - return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."} + except Exception as error: + msg = "User login failed! Unexpected server error. " + repr(error) + print(msg) + return {"success": False, "authToken": "none", "message": msg} @app.get("/api") def getapi(): return {"Hello": "API!"} - diff --git a/securityHandler.py b/securityHandler.py index 1d74183..8ed9896 100644 --- a/securityHandler.py +++ b/securityHandler.py @@ -36,12 +36,13 @@ def verifyRehash(hash: str) -> bool: def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool: hash = userHandler.getHashValueByUserID(dbConnection, userID) - debugPrint("Now verifying password against hash for user ID " + userid + "...") + userIDstr = str(userID) + debugPrint("Now verifying password against hash for user ID " + userIDstr + "...") if verifyPassword(password, hash): - debugPrint("(USER ID " + userID + ") Password verification success!") + debugPrint("(USER ID " + userIDstr + ") Password verification success!") if verifyRehash(hash): - debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...") + debugPrint("(USER ID " + userIDstr + ") Hash needs to be rehashed! Will now rehash...") return True else: - debugPrint("(USER ID " + userID + ") Password verification failure!") + debugPrint("(USER ID " + userIDstr + ") Password verification failure!") return False diff --git a/tokenHandler.py b/tokenHandler.py index 5788058..9a4beff 100644 --- a/tokenHandler.py +++ b/tokenHandler.py @@ -1,3 +1,5 @@ +import datetime +import secrets import psycopg2 from psycopg2 import sql @@ -10,5 +12,15 @@ def debugPrint(msg: str) -> None: print("(TOKEN HANDLER) PRINT: " + msg) def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str: - debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...") - return "sha" + debugPrint("Now initialising new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...") + randToken = secrets.token_hex(1023) + dateCreated = datetime.datetime.now() + if rememberMe: + dateExpiry = dateCreated + datetime.timedelta(days=30) + else: + dateExpiry = dateCreated + datetime.timedelta(days=1) + dbHandler.insertRow(dbConnection, + 'authtokens', + ['token', 'ownerid', 'datecreated', 'dateexpiry', 'iplocationcreated'], + [randToken, userID, dateCreated.strftime("%G-%m-%d %X"), dateExpiry.strftime("%G-%m-%d %X"), locationIP]) + return randToken diff --git a/userHandler.py b/userHandler.py index eb1bd73..a11083f 100644 --- a/userHandler.py +++ b/userHandler.py @@ -11,7 +11,7 @@ def debugPrint(msg: str) -> None: print("(USER HANDLER) PRINT: " + msg) def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str): - debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor) + debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor + "...") dbHandler.insertRow(dbConnection, 'users', ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'], @@ -20,7 +20,7 @@ def createUser(dbConnection: psycopg2.extensions.connection, username: str, firs def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool: return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username) -def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str: +def getHashValueByUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str: return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash") def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int: