Implement basic error handling for database connect function, implement several new database handler helper functions, init security handler password hashing via argon

This commit is contained in:
Curt Spark 2024-04-19 16:12:16 +01:00
parent bcb15bcb28
commit 8c91612a5c
3 changed files with 128 additions and 12 deletions

View File

@ -2,18 +2,29 @@ import typing
import psycopg2
from psycopg2 import sql
debug: bool = True
def debugPrint(msg: str) -> None:
print(msg)
if debug:
print("(DB HANDLER) PRINT: " + msg)
def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
print("(DB HANDLER) " + dbConnection.notices[i])
if debug:
print("(DB HANDLER) " + dbConnection.notices[i])
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
return psycopg2.connect(database=databaseOption,
debugPrint("Attempting to connect to database...")
try:
dbConnection = psycopg2.connect(database=databaseOption,
host=hostOption,
user=userOption,
password=passwordOption,
port=portOption)
debugPrint("Successfully connected to database!")
return dbConnection
except:
debugPrint("Error occurred connecting to database! Exiting...")
sys.exit(1)
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
dbCursor = dbConnection.cursor()
@ -55,9 +66,58 @@ def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
)
)
debugPrint(sanitisedQuery.as_string(dbConnection))
commitQuery(dbConnection, sanitisedQuery)
_commitQuery(dbConnection, sanitisedQuery)
def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT {field} FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
field=sql.Identifier(tableField),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT * FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0]
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
debugPrint("Attempting to get ID by username " + username + "...")
sanitisedQuery = sql.SQL("""
SELECT id FROM users WHERE "username" = {username}
""").format(
username=sql.Literal(username)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
sanitisedQuery = sql.SQL("""
SELECT EXISTS(
SELECT
{fieldName}
FROM
{table}
WHERE
{fieldName} = {fieldValue}
);
""").format(
table=sql.Identifier(tableName),
fieldName=sql.Identifier(fieldName),
fieldValue=sql.Literal(fieldValue)
)
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
debugPrint("Commit query executing...")
dbCursor = dbConnection.cursor()
dbCursor.execute(query)
@ -65,8 +125,7 @@ def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composa
dbResults = dbCursor.fetchall()
dbCursor.close()
return dbResults
def execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
debugPrint("Exec query executing...")
dbCursor = dbConnection.cursor()
dbCursor.execute(query)

17
main.py
View File

@ -1,3 +1,4 @@
import sys
from typing import Union
from fastapi import FastAPI
@ -50,12 +51,18 @@ ID SERIAL PRIMARY KEY,
Name VARCHAR(255)
""")
dbHandler.insertRow(dbConnection,
'users',
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
#dbHandler.insertRow(dbConnection,
# 'users',
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
print(dbHandler.execQuery(dbConnection, "SELECT * FROM users"))
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
print(dbHandler.getRowByID(dbConnection, "users", 5))
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
print("It exists!")
else:
print("It does NOT exist!")
dbConnection.close()

50
securityHandler.py Normal file
View File

@ -0,0 +1,50 @@
import typing
import argon2
import psycopg2
import dbHandler
debug: bool = True
passwordHasher = argon2.PasswordHasher()
def debugPrint(msg: str) -> None:
if debug:
print("(SECURITY HANDLER) PRINT: " + msg)
def hashPassword(password: str) -> str:
return passwordHasher.hash(password)
def verifyPassword(password: str, hash: str) -> bool:
try:
if passwordHasher.verify(hash, password):
return True
else:
return False
except:
return False
def verifyRehash(hash: str) -> bool:
try:
if passwordHasher.check_needs_rehash(hash):
return True
else:
return False
except:
return False
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash")
debugPrint("Now verifying password against hash for user ID " + userid + "...")
if verifyPassword(password, hash):
debugPrint("(USER ID) " + userID + " Password verification success!")
if verifyRehash(hash):
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...")
return True
else:
debugPrint("(USER ID) " + userID + " Password verification failure!")
return False
hashed: str = hashPassword("testing")
print(verifyPassword("testing", hashed))
print(verifyRehash(hashed))