diff --git a/dbHandler.py b/dbHandler.py index 0d17ba6..4bc5cd9 100644 --- a/dbHandler.py +++ b/dbHandler.py @@ -2,18 +2,29 @@ import typing import psycopg2 from psycopg2 import sql +debug: bool = True + def debugPrint(msg: str) -> None: - print(msg) + if debug: + print("(DB HANDLER) PRINT: " + msg) def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None: - print("(DB HANDLER) " + dbConnection.notices[i]) + if debug: + print("(DB HANDLER) " + dbConnection.notices[i]) def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection: - return psycopg2.connect(database=databaseOption, + debugPrint("Attempting to connect to database...") + try: + dbConnection = psycopg2.connect(database=databaseOption, host=hostOption, user=userOption, password=passwordOption, port=portOption) + debugPrint("Successfully connected to database!") + return dbConnection + except: + debugPrint("Error occurred connecting to database! Exiting...") + sys.exit(1) def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str): dbCursor = dbConnection.cursor() @@ -55,9 +66,58 @@ def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tabl ) ) debugPrint(sanitisedQuery.as_string(dbConnection)) - commitQuery(dbConnection, sanitisedQuery) + _commitQuery(dbConnection, sanitisedQuery) -def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: +def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str: + debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...") + sanitisedQuery = sql.SQL(""" + SELECT {field} FROM {table} WHERE "id" = {id} + """).format( + table=sql.Identifier(tableName), + field=sql.Identifier(tableField), + id=sql.Literal(RowID) + ) + return _execQuery(dbConnection, sanitisedQuery)[0][0] + +def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple: + debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...") + sanitisedQuery = sql.SQL(""" + SELECT * FROM {table} WHERE "id" = {id} + """).format( + table=sql.Identifier(tableName), + id=sql.Literal(RowID) + ) + return _execQuery(dbConnection, sanitisedQuery)[0] + +def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int: + debugPrint("Attempting to get ID by username " + username + "...") + sanitisedQuery = sql.SQL(""" + SELECT id FROM users WHERE "username" = {username} + """).format( + username=sql.Literal(username) + ) + return _execQuery(dbConnection, sanitisedQuery)[0][0] + +def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool: + debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...") + sanitisedQuery = sql.SQL(""" + SELECT EXISTS( + SELECT + {fieldName} + FROM + {table} + WHERE + {fieldName} = {fieldValue} + ); + """).format( + table=sql.Identifier(tableName), + fieldName=sql.Identifier(fieldName), + fieldValue=sql.Literal(fieldValue) + ) + return bool(_execQuery(dbConnection, sanitisedQuery)[0][0]) + +# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks) +def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: debugPrint("Commit query executing...") dbCursor = dbConnection.cursor() dbCursor.execute(query) @@ -65,8 +125,7 @@ def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composa dbResults = dbCursor.fetchall() dbCursor.close() return dbResults - -def execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: +def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: debugPrint("Exec query executing...") dbCursor = dbConnection.cursor() dbCursor.execute(query) diff --git a/main.py b/main.py index 85fe09f..3c016f0 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import sys from typing import Union from fastapi import FastAPI @@ -50,12 +51,18 @@ ID SERIAL PRIMARY KEY, Name VARCHAR(255) """) -dbHandler.insertRow(dbConnection, - 'users', - ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'], - ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256']) +#dbHandler.insertRow(dbConnection, +# 'users', +# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'], +# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256']) -print(dbHandler.execQuery(dbConnection, "SELECT * FROM users")) +print(dbHandler._execQuery(dbConnection, "SELECT * FROM users")) +print(dbHandler.getRowByID(dbConnection, "users", 5)) +print(dbHandler.getFieldByID(dbConnection, "users", 5, "description")) +if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"): + print("It exists!") +else: + print("It does NOT exist!") dbConnection.close() diff --git a/securityHandler.py b/securityHandler.py new file mode 100644 index 0000000..3fa7a39 --- /dev/null +++ b/securityHandler.py @@ -0,0 +1,50 @@ +import typing +import argon2 +import psycopg2 + +import dbHandler + +debug: bool = True + +passwordHasher = argon2.PasswordHasher() + +def debugPrint(msg: str) -> None: + if debug: + print("(SECURITY HANDLER) PRINT: " + msg) + +def hashPassword(password: str) -> str: + return passwordHasher.hash(password) + +def verifyPassword(password: str, hash: str) -> bool: + try: + if passwordHasher.verify(hash, password): + return True + else: + return False + except: + return False + +def verifyRehash(hash: str) -> bool: + try: + if passwordHasher.check_needs_rehash(hash): + return True + else: + return False + except: + return False + +def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool: + hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash") + debugPrint("Now verifying password against hash for user ID " + userid + "...") + if verifyPassword(password, hash): + debugPrint("(USER ID) " + userID + " Password verification success!") + if verifyRehash(hash): + debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...") + return True + else: + debugPrint("(USER ID) " + userID + " Password verification failure!") + return False + +hashed: str = hashPassword("testing") +print(verifyPassword("testing", hashed)) +print(verifyRehash(hashed))