Implement basic error handling for database connect function, implement several new database handler helper functions, init security handler password hashing via argon
This commit is contained in:
parent
bcb15bcb28
commit
8c91612a5c
73
dbHandler.py
73
dbHandler.py
|
|
@ -2,18 +2,29 @@ import typing
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2 import sql
|
from psycopg2 import sql
|
||||||
|
|
||||||
|
debug: bool = True
|
||||||
|
|
||||||
def debugPrint(msg: str) -> None:
|
def debugPrint(msg: str) -> None:
|
||||||
print(msg)
|
if debug:
|
||||||
|
print("(DB HANDLER) PRINT: " + msg)
|
||||||
|
|
||||||
def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
|
def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
|
||||||
print("(DB HANDLER) " + dbConnection.notices[i])
|
if debug:
|
||||||
|
print("(DB HANDLER) " + dbConnection.notices[i])
|
||||||
|
|
||||||
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
|
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
|
||||||
return psycopg2.connect(database=databaseOption,
|
debugPrint("Attempting to connect to database...")
|
||||||
|
try:
|
||||||
|
dbConnection = psycopg2.connect(database=databaseOption,
|
||||||
host=hostOption,
|
host=hostOption,
|
||||||
user=userOption,
|
user=userOption,
|
||||||
password=passwordOption,
|
password=passwordOption,
|
||||||
port=portOption)
|
port=portOption)
|
||||||
|
debugPrint("Successfully connected to database!")
|
||||||
|
return dbConnection
|
||||||
|
except:
|
||||||
|
debugPrint("Error occurred connecting to database! Exiting...")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
|
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
|
||||||
dbCursor = dbConnection.cursor()
|
dbCursor = dbConnection.cursor()
|
||||||
|
|
@ -55,9 +66,58 @@ def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
debugPrint(sanitisedQuery.as_string(dbConnection))
|
debugPrint(sanitisedQuery.as_string(dbConnection))
|
||||||
commitQuery(dbConnection, sanitisedQuery)
|
_commitQuery(dbConnection, sanitisedQuery)
|
||||||
|
|
||||||
def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
|
||||||
|
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT {field} FROM {table} WHERE "id" = {id}
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
field=sql.Identifier(tableField),
|
||||||
|
id=sql.Literal(RowID)
|
||||||
|
)
|
||||||
|
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
||||||
|
|
||||||
|
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
|
||||||
|
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT * FROM {table} WHERE "id" = {id}
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
id=sql.Literal(RowID)
|
||||||
|
)
|
||||||
|
return _execQuery(dbConnection, sanitisedQuery)[0]
|
||||||
|
|
||||||
|
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
|
||||||
|
debugPrint("Attempting to get ID by username " + username + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT id FROM users WHERE "username" = {username}
|
||||||
|
""").format(
|
||||||
|
username=sql.Literal(username)
|
||||||
|
)
|
||||||
|
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
||||||
|
|
||||||
|
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
|
||||||
|
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT
|
||||||
|
{fieldName}
|
||||||
|
FROM
|
||||||
|
{table}
|
||||||
|
WHERE
|
||||||
|
{fieldName} = {fieldValue}
|
||||||
|
);
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
fieldName=sql.Identifier(fieldName),
|
||||||
|
fieldValue=sql.Literal(fieldValue)
|
||||||
|
)
|
||||||
|
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||||
|
|
||||||
|
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
|
||||||
|
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
||||||
debugPrint("Commit query executing...")
|
debugPrint("Commit query executing...")
|
||||||
dbCursor = dbConnection.cursor()
|
dbCursor = dbConnection.cursor()
|
||||||
dbCursor.execute(query)
|
dbCursor.execute(query)
|
||||||
|
|
@ -65,8 +125,7 @@ def commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composa
|
||||||
dbResults = dbCursor.fetchall()
|
dbResults = dbCursor.fetchall()
|
||||||
dbCursor.close()
|
dbCursor.close()
|
||||||
return dbResults
|
return dbResults
|
||||||
|
def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
||||||
def execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
|
||||||
debugPrint("Exec query executing...")
|
debugPrint("Exec query executing...")
|
||||||
dbCursor = dbConnection.cursor()
|
dbCursor = dbConnection.cursor()
|
||||||
dbCursor.execute(query)
|
dbCursor.execute(query)
|
||||||
|
|
|
||||||
17
main.py
17
main.py
|
|
@ -1,3 +1,4 @@
|
||||||
|
import sys
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
@ -50,12 +51,18 @@ ID SERIAL PRIMARY KEY,
|
||||||
Name VARCHAR(255)
|
Name VARCHAR(255)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
dbHandler.insertRow(dbConnection,
|
#dbHandler.insertRow(dbConnection,
|
||||||
'users',
|
# 'users',
|
||||||
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
||||||
['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
|
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
|
||||||
|
|
||||||
print(dbHandler.execQuery(dbConnection, "SELECT * FROM users"))
|
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
|
||||||
|
print(dbHandler.getRowByID(dbConnection, "users", 5))
|
||||||
|
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
|
||||||
|
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
|
||||||
|
print("It exists!")
|
||||||
|
else:
|
||||||
|
print("It does NOT exist!")
|
||||||
|
|
||||||
dbConnection.close()
|
dbConnection.close()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
import typing
|
||||||
|
import argon2
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
import dbHandler
|
||||||
|
|
||||||
|
debug: bool = True
|
||||||
|
|
||||||
|
passwordHasher = argon2.PasswordHasher()
|
||||||
|
|
||||||
|
def debugPrint(msg: str) -> None:
|
||||||
|
if debug:
|
||||||
|
print("(SECURITY HANDLER) PRINT: " + msg)
|
||||||
|
|
||||||
|
def hashPassword(password: str) -> str:
|
||||||
|
return passwordHasher.hash(password)
|
||||||
|
|
||||||
|
def verifyPassword(password: str, hash: str) -> bool:
|
||||||
|
try:
|
||||||
|
if passwordHasher.verify(hash, password):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def verifyRehash(hash: str) -> bool:
|
||||||
|
try:
|
||||||
|
if passwordHasher.check_needs_rehash(hash):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
|
||||||
|
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash")
|
||||||
|
debugPrint("Now verifying password against hash for user ID " + userid + "...")
|
||||||
|
if verifyPassword(password, hash):
|
||||||
|
debugPrint("(USER ID) " + userID + " Password verification success!")
|
||||||
|
if verifyRehash(hash):
|
||||||
|
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
debugPrint("(USER ID) " + userID + " Password verification failure!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
hashed: str = hashPassword("testing")
|
||||||
|
print(verifyPassword("testing", hashed))
|
||||||
|
print(verifyRehash(hashed))
|
||||||
Loading…
Reference in New Issue