diff --git a/dbHandler.py b/dbHandler.py index 4bc5cd9..f2f208e 100644 --- a/dbHandler.py +++ b/dbHandler.py @@ -12,18 +12,21 @@ def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> No if debug: print("(DB HANDLER) " + dbConnection.notices[i]) +def errorPrint(msg: str) -> None: + print("(DB HANDLER) ERROR: " + msg) + def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection: debugPrint("Attempting to connect to database...") try: dbConnection = psycopg2.connect(database=databaseOption, - host=hostOption, - user=userOption, - password=passwordOption, - port=portOption) + host=hostOption, + user=userOption, + password=passwordOption, + port=portOption) debugPrint("Successfully connected to database!") return dbConnection except: - debugPrint("Error occurred connecting to database! Exiting...") + errorPrint("Error occurred connecting to database! Exiting...") sys.exit(1) def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str): @@ -50,72 +53,6 @@ def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tabl dbConnection.commit() dbCursor.close() -def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list): - debugPrint("Attempting to insert new row...") - sanitisedQuery = sql.SQL(""" - INSERT INTO {table} ({format}) - VALUES ({values}) - RETURNING *; - """).format( - table=sql.Identifier(tableName), - format=sql.SQL(", ").join( - sql.Identifier(value) for value in tableFormat - ), - values=sql.SQL(", ").join( - sql.Literal(value) for value in tableValues - ) - ) - debugPrint(sanitisedQuery.as_string(dbConnection)) - _commitQuery(dbConnection, sanitisedQuery) - -def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str: - debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...") - sanitisedQuery = sql.SQL(""" - SELECT {field} FROM {table} WHERE "id" = {id} - """).format( - table=sql.Identifier(tableName), - field=sql.Identifier(tableField), - id=sql.Literal(RowID) - ) - return _execQuery(dbConnection, sanitisedQuery)[0][0] - -def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple: - debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...") - sanitisedQuery = sql.SQL(""" - SELECT * FROM {table} WHERE "id" = {id} - """).format( - table=sql.Identifier(tableName), - id=sql.Literal(RowID) - ) - return _execQuery(dbConnection, sanitisedQuery)[0] - -def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int: - debugPrint("Attempting to get ID by username " + username + "...") - sanitisedQuery = sql.SQL(""" - SELECT id FROM users WHERE "username" = {username} - """).format( - username=sql.Literal(username) - ) - return _execQuery(dbConnection, sanitisedQuery)[0][0] - -def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool: - debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...") - sanitisedQuery = sql.SQL(""" - SELECT EXISTS( - SELECT - {fieldName} - FROM - {table} - WHERE - {fieldName} = {fieldValue} - ); - """).format( - table=sql.Identifier(tableName), - fieldName=sql.Identifier(fieldName), - fieldValue=sql.Literal(fieldValue) - ) - return bool(_execQuery(dbConnection, sanitisedQuery)[0][0]) - # These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks) def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list: debugPrint("Commit query executing...") @@ -132,3 +69,62 @@ def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composab dbResults = dbCursor.fetchall() dbCursor.close() return dbResults + +# Callable helper functions +def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list): + debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...") + sanitisedQuery = sql.SQL(""" + INSERT INTO {table} ({format}) + VALUES ({values}) + RETURNING *; + """).format( + table=sql.Identifier(tableName), + format=sql.SQL(", ").join( + sql.Identifier(value) for value in tableFormat + ), + values=sql.SQL(", ").join( + sql.Literal(value) for value in tableValues + ) + ) + debugPrint(sanitisedQuery.as_string(dbConnection)) + _commitQuery(dbConnection, sanitisedQuery) + +def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str: + debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...") + sanitisedQuery = sql.SQL(""" + SELECT {field} FROM {table} WHERE "id" = {id} + """).format( + table=sql.Identifier(tableName), + field=sql.Identifier(tableField), + id=sql.Literal(RowID) + ) + return str(_execQuery(dbConnection, sanitisedQuery)[0][0]) + +def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple: + debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...") + sanitisedQuery = sql.SQL(""" + SELECT * FROM {table} WHERE "id" = {id} + """).format( + table=sql.Identifier(tableName), + id=sql.Literal(RowID) + ) + return tuple(_execQuery(dbConnection, sanitisedQuery)[0]) + +def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool: + debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...") + sanitisedQuery = sql.SQL(""" + SELECT EXISTS( + SELECT + {fieldName} + FROM + {table} + WHERE + {fieldName} = {fieldValue} + ); + """).format( + table=sql.Identifier(tableName), + fieldName=sql.Identifier(fieldName), + fieldValue=sql.Literal(fieldValue) + ) + return bool(_execQuery(dbConnection, sanitisedQuery)[0][0]) + diff --git a/main.py b/main.py index 3c016f0..993f8d3 100644 --- a/main.py +++ b/main.py @@ -1,72 +1,78 @@ import sys +import atexit +import signal from typing import Union -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import dbHandler +import userHandler +import securityHandler -dbConnection = dbHandler.connect("blorgdb", - "172.20.0.10", - "dev", - "dev", - "5432") -dbHandler.initTable(dbConnection, "Users", """ -ID SERIAL PRIMARY KEY, -Username VARCHAR(255), -FirstName VARCHAR(255), -LastName VARCHAR(255), -Description VARCHAR(255), -Country VARCHAR(255), -Theme VARCHAR(255), -AccentColor VARCHAR(255), -PasswordHash VARCHAR(255) -""") -dbHandler.initTable(dbConnection, "SignOns", """ -ID SERIAL PRIMARY KEY, -UserID VARCHAR(255), -LoginSuccess BOOLEAN, -DateAttempted VARCHAR(255), -IPLocationAttempted VARCHAR(255) -""") -dbHandler.initTable(dbConnection, "AuthTokens", """ -ID SERIAL PRIMARY KEY, -Token VARCHAR(255), -OwnerID INTEGER, -DateCreated TIMESTAMP, -DateExpiry TIMESTAMP, -IPLocationCreated VARCHAR(255) -""") -dbHandler.initTable(dbConnection, "Blogs", """ -ID SERIAL PRIMARY KEY, -AuthorID INTEGER, -CategoryID INTEGER, -DatePosted TIMESTAMP, -Description VARCHAR(255) -""") -dbHandler.initTable(dbConnection, "Categories", """ -ID SERIAL PRIMARY KEY, -Name VARCHAR(255) -""") +dbConnection = None -#dbHandler.insertRow(dbConnection, -# 'users', -# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'], -# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256']) +def apiInit(): + dbConnection = dbHandler.connect("blorgdb", + "172.20.0.10", + "dev", + "dev", + "5432") + dbHandler.initTable(dbConnection, "Users", """ + ID SERIAL PRIMARY KEY, + Username VARCHAR(255), + Email VARCHAR(255), + FirstName VARCHAR(255), + LastName VARCHAR(255), + Description VARCHAR(255), + Country VARCHAR(255), + Theme VARCHAR(255), + AccentColor VARCHAR(255), + PasswordHash VARCHAR(255) + """) + dbHandler.initTable(dbConnection, "SignOns", """ + ID SERIAL PRIMARY KEY, + UserID INTEGER, + LoginSuccess BOOLEAN, + DateAttempted TIMESTAMP, + IPLocationAttempted VARCHAR(255) + """) + dbHandler.initTable(dbConnection, "AuthTokens", """ + ID SERIAL PRIMARY KEY, + Token VARCHAR(255), + OwnerID INTEGER, + DateCreated TIMESTAMP, + DateExpiry TIMESTAMP, + IPLocationCreated VARCHAR(255) + """) + dbHandler.initTable(dbConnection, "Blogs", """ + ID SERIAL PRIMARY KEY, + AuthorID INTEGER, + CategoryID INTEGER, + DatePosted TIMESTAMP, + Description VARCHAR(255) + """) + dbHandler.initTable(dbConnection, "Categories", """ + ID SERIAL PRIMARY KEY, + Name VARCHAR(255) + """) + userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser") -print(dbHandler._execQuery(dbConnection, "SELECT * FROM users")) -print(dbHandler.getRowByID(dbConnection, "users", 5)) -print(dbHandler.getFieldByID(dbConnection, "users", 5, "description")) -if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"): - print("It exists!") -else: - print("It does NOT exist!") +def apiCleanup(): + dbConnection.close() -dbConnection.close() +@asynccontextmanager +async def apiLifespan(app: FastAPI): + # API Init + apiInit() + + # API Clean up + yield + apiCleanup() -app = FastAPI() +app = FastAPI(lifespan=apiLifespan) origins = [ "http://localhost", @@ -88,13 +94,31 @@ def getroot(): class ApiBody(BaseModel): username: str password: str - @app.post("/api") def postapi(body: ApiBody): print(body.username) print(body.password) return body +class loginBody(BaseModel): + username: str + password: str + rememberMe: bool +@app.post("/api/login") +def postlogin(body: loginBody, request: Request): + try: + if userHandler.checkUserExistence(dbConnection, loginBody.username): + userID = userHandler.getIDByUsername(dbConnection, loginBody.username) + if securityHandler.handlePassword(dbConnection, loginBody.password, userID): + return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"} + else: + return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."} + else: + return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."} + except: + return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."} + @app.get("/api") def getapi(): return {"Hello": "API!"} + diff --git a/securityHandler.py b/securityHandler.py index 3fa7a39..1d74183 100644 --- a/securityHandler.py +++ b/securityHandler.py @@ -3,6 +3,7 @@ import argon2 import psycopg2 import dbHandler +import userHandler debug: bool = True @@ -34,17 +35,13 @@ def verifyRehash(hash: str) -> bool: return False def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool: - hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash") + hash = userHandler.getHashValueByUserID(dbConnection, userID) debugPrint("Now verifying password against hash for user ID " + userid + "...") if verifyPassword(password, hash): - debugPrint("(USER ID) " + userID + " Password verification success!") + debugPrint("(USER ID " + userID + ") Password verification success!") if verifyRehash(hash): - debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...") + debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...") return True else: - debugPrint("(USER ID) " + userID + " Password verification failure!") + debugPrint("(USER ID " + userID + ") Password verification failure!") return False - -hashed: str = hashPassword("testing") -print(verifyPassword("testing", hashed)) -print(verifyRehash(hashed)) diff --git a/tokenHandler.py b/tokenHandler.py new file mode 100644 index 0000000..5788058 --- /dev/null +++ b/tokenHandler.py @@ -0,0 +1,14 @@ +import psycopg2 +from psycopg2 import sql + +import dbHandler + +debug: bool = True + +def debugPrint(msg: str) -> None: + if debug: + print("(TOKEN HANDLER) PRINT: " + msg) + +def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str: + debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...") + return "sha" diff --git a/userHandler.py b/userHandler.py new file mode 100644 index 0000000..eb1bd73 --- /dev/null +++ b/userHandler.py @@ -0,0 +1,33 @@ +import psycopg2 +from psycopg2 import sql + +import dbHandler +import securityHandler + +debug: bool = True + +def debugPrint(msg: str) -> None: + if debug: + print("(USER HANDLER) PRINT: " + msg) + +def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str): + debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor) + dbHandler.insertRow(dbConnection, + 'users', + ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'], + [username, firstName, lastName, description, country, theme, accentColor, securityHandler.hashPassword(password)]) + +def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool: + return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username) + +def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str: + return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash") + +def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int: + debugPrint("Attempting to get ID of username " + username + "...") + sanitisedQuery = sql.SQL(""" + SELECT id FROM users WHERE "username" = {username} + """).format( + username=sql.Literal(username) + ) + return int(dbHandler._execQuery(dbConnection, sanitisedQuery)[0][0])