Move user specific functions to new userHandler module, setup most of logic for password authentication, init token handling module. Properly setup API init and cleanup functions

This commit is contained in:
Curt Spark 2024-04-20 16:59:47 +01:00
parent 8c91612a5c
commit dabf0c8977
5 changed files with 200 additions and 136 deletions

View File

@ -12,6 +12,9 @@ def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> No
if debug: if debug:
print("(DB HANDLER) " + dbConnection.notices[i]) print("(DB HANDLER) " + dbConnection.notices[i])
def errorPrint(msg: str) -> None:
print("(DB HANDLER) ERROR: " + msg)
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection: def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
debugPrint("Attempting to connect to database...") debugPrint("Attempting to connect to database...")
try: try:
@ -23,7 +26,7 @@ def connect(databaseOption: str, hostOption: str, userOption: str, passwordOptio
debugPrint("Successfully connected to database!") debugPrint("Successfully connected to database!")
return dbConnection return dbConnection
except: except:
debugPrint("Error occurred connecting to database! Exiting...") errorPrint("Error occurred connecting to database! Exiting...")
sys.exit(1) sys.exit(1)
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str): def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
@ -50,72 +53,6 @@ def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
dbConnection.commit() dbConnection.commit()
dbCursor.close() dbCursor.close()
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
debugPrint("Attempting to insert new row...")
sanitisedQuery = sql.SQL("""
INSERT INTO {table} ({format})
VALUES ({values})
RETURNING *;
""").format(
table=sql.Identifier(tableName),
format=sql.SQL(", ").join(
sql.Identifier(value) for value in tableFormat
),
values=sql.SQL(", ").join(
sql.Literal(value) for value in tableValues
)
)
debugPrint(sanitisedQuery.as_string(dbConnection))
_commitQuery(dbConnection, sanitisedQuery)
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT {field} FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
field=sql.Identifier(tableField),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT * FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0]
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
debugPrint("Attempting to get ID by username " + username + "...")
sanitisedQuery = sql.SQL("""
SELECT id FROM users WHERE "username" = {username}
""").format(
username=sql.Literal(username)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
sanitisedQuery = sql.SQL("""
SELECT EXISTS(
SELECT
{fieldName}
FROM
{table}
WHERE
{fieldName} = {fieldValue}
);
""").format(
table=sql.Identifier(tableName),
fieldName=sql.Identifier(fieldName),
fieldValue=sql.Literal(fieldValue)
)
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks) # These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
debugPrint("Commit query executing...") debugPrint("Commit query executing...")
@ -132,3 +69,62 @@ def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composab
dbResults = dbCursor.fetchall() dbResults = dbCursor.fetchall()
dbCursor.close() dbCursor.close()
return dbResults return dbResults
# Callable helper functions
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
INSERT INTO {table} ({format})
VALUES ({values})
RETURNING *;
""").format(
table=sql.Identifier(tableName),
format=sql.SQL(", ").join(
sql.Identifier(value) for value in tableFormat
),
values=sql.SQL(", ").join(
sql.Literal(value) for value in tableValues
)
)
debugPrint(sanitisedQuery.as_string(dbConnection))
_commitQuery(dbConnection, sanitisedQuery)
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT {field} FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
field=sql.Identifier(tableField),
id=sql.Literal(RowID)
)
return str(_execQuery(dbConnection, sanitisedQuery)[0][0])
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT * FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
id=sql.Literal(RowID)
)
return tuple(_execQuery(dbConnection, sanitisedQuery)[0])
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
sanitisedQuery = sql.SQL("""
SELECT EXISTS(
SELECT
{fieldName}
FROM
{table}
WHERE
{fieldName} = {fieldValue}
);
""").format(
table=sql.Identifier(tableName),
fieldName=sql.Identifier(fieldName),
fieldValue=sql.Literal(fieldValue)
)
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])

60
main.py
View File

@ -1,13 +1,20 @@
import sys import sys
import atexit
import signal
from typing import Union from typing import Union
from fastapi import FastAPI from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
import dbHandler import dbHandler
import userHandler
import securityHandler
dbConnection = None
def apiInit():
dbConnection = dbHandler.connect("blorgdb", dbConnection = dbHandler.connect("blorgdb",
"172.20.0.10", "172.20.0.10",
"dev", "dev",
@ -16,6 +23,7 @@ dbConnection = dbHandler.connect("blorgdb",
dbHandler.initTable(dbConnection, "Users", """ dbHandler.initTable(dbConnection, "Users", """
ID SERIAL PRIMARY KEY, ID SERIAL PRIMARY KEY,
Username VARCHAR(255), Username VARCHAR(255),
Email VARCHAR(255),
FirstName VARCHAR(255), FirstName VARCHAR(255),
LastName VARCHAR(255), LastName VARCHAR(255),
Description VARCHAR(255), Description VARCHAR(255),
@ -26,9 +34,9 @@ PasswordHash VARCHAR(255)
""") """)
dbHandler.initTable(dbConnection, "SignOns", """ dbHandler.initTable(dbConnection, "SignOns", """
ID SERIAL PRIMARY KEY, ID SERIAL PRIMARY KEY,
UserID VARCHAR(255), UserID INTEGER,
LoginSuccess BOOLEAN, LoginSuccess BOOLEAN,
DateAttempted VARCHAR(255), DateAttempted TIMESTAMP,
IPLocationAttempted VARCHAR(255) IPLocationAttempted VARCHAR(255)
""") """)
dbHandler.initTable(dbConnection, "AuthTokens", """ dbHandler.initTable(dbConnection, "AuthTokens", """
@ -50,23 +58,21 @@ dbHandler.initTable(dbConnection, "Categories", """
ID SERIAL PRIMARY KEY, ID SERIAL PRIMARY KEY,
Name VARCHAR(255) Name VARCHAR(255)
""") """)
userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser")
#dbHandler.insertRow(dbConnection, def apiCleanup():
# 'users',
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
print(dbHandler.getRowByID(dbConnection, "users", 5))
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
print("It exists!")
else:
print("It does NOT exist!")
dbConnection.close() dbConnection.close()
app = FastAPI() @asynccontextmanager
async def apiLifespan(app: FastAPI):
# API Init
apiInit()
# API Clean up
yield
apiCleanup()
app = FastAPI(lifespan=apiLifespan)
origins = [ origins = [
"http://localhost", "http://localhost",
@ -88,13 +94,31 @@ def getroot():
class ApiBody(BaseModel): class ApiBody(BaseModel):
username: str username: str
password: str password: str
@app.post("/api") @app.post("/api")
def postapi(body: ApiBody): def postapi(body: ApiBody):
print(body.username) print(body.username)
print(body.password) print(body.password)
return body return body
class loginBody(BaseModel):
username: str
password: str
rememberMe: bool
@app.post("/api/login")
def postlogin(body: loginBody, request: Request):
try:
if userHandler.checkUserExistence(dbConnection, loginBody.username):
userID = userHandler.getIDByUsername(dbConnection, loginBody.username)
if securityHandler.handlePassword(dbConnection, loginBody.password, userID):
return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"}
else:
return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."}
else:
return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."}
except:
return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."}
@app.get("/api") @app.get("/api")
def getapi(): def getapi():
return {"Hello": "API!"} return {"Hello": "API!"}

View File

@ -3,6 +3,7 @@ import argon2
import psycopg2 import psycopg2
import dbHandler import dbHandler
import userHandler
debug: bool = True debug: bool = True
@ -34,17 +35,13 @@ def verifyRehash(hash: str) -> bool:
return False return False
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool: def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash") hash = userHandler.getHashValueByUserID(dbConnection, userID)
debugPrint("Now verifying password against hash for user ID " + userid + "...") debugPrint("Now verifying password against hash for user ID " + userid + "...")
if verifyPassword(password, hash): if verifyPassword(password, hash):
debugPrint("(USER ID) " + userID + " Password verification success!") debugPrint("(USER ID " + userID + ") Password verification success!")
if verifyRehash(hash): if verifyRehash(hash):
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...") debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...")
return True return True
else: else:
debugPrint("(USER ID) " + userID + " Password verification failure!") debugPrint("(USER ID " + userID + ") Password verification failure!")
return False return False
hashed: str = hashPassword("testing")
print(verifyPassword("testing", hashed))
print(verifyRehash(hashed))

14
tokenHandler.py Normal file
View File

@ -0,0 +1,14 @@
import psycopg2
from psycopg2 import sql
import dbHandler
debug: bool = True
def debugPrint(msg: str) -> None:
if debug:
print("(TOKEN HANDLER) PRINT: " + msg)
def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str:
debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...")
return "sha"

33
userHandler.py Normal file
View File

@ -0,0 +1,33 @@
import psycopg2
from psycopg2 import sql
import dbHandler
import securityHandler
debug: bool = True
def debugPrint(msg: str) -> None:
if debug:
print("(USER HANDLER) PRINT: " + msg)
def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str):
debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor)
dbHandler.insertRow(dbConnection,
'users',
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
[username, firstName, lastName, description, country, theme, accentColor, securityHandler.hashPassword(password)])
def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool:
return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username)
def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str:
return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash")
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
debugPrint("Attempting to get ID of username " + username + "...")
sanitisedQuery = sql.SQL("""
SELECT id FROM users WHERE "username" = {username}
""").format(
username=sql.Literal(username)
)
return int(dbHandler._execQuery(dbConnection, sanitisedQuery)[0][0])