Move user specific functions to new userHandler module, setup most of logic for password authentication, init token handling module. Properly setup API init and cleanup functions

This commit is contained in:
Curt Spark 2024-04-20 16:59:47 +01:00
parent 8c91612a5c
commit dabf0c8977
5 changed files with 200 additions and 136 deletions

View File

@ -12,18 +12,21 @@ def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> No
if debug:
print("(DB HANDLER) " + dbConnection.notices[i])
def errorPrint(msg: str) -> None:
print("(DB HANDLER) ERROR: " + msg)
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
debugPrint("Attempting to connect to database...")
try:
dbConnection = psycopg2.connect(database=databaseOption,
host=hostOption,
user=userOption,
password=passwordOption,
port=portOption)
host=hostOption,
user=userOption,
password=passwordOption,
port=portOption)
debugPrint("Successfully connected to database!")
return dbConnection
except:
debugPrint("Error occurred connecting to database! Exiting...")
errorPrint("Error occurred connecting to database! Exiting...")
sys.exit(1)
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
@ -50,72 +53,6 @@ def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
dbConnection.commit()
dbCursor.close()
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
debugPrint("Attempting to insert new row...")
sanitisedQuery = sql.SQL("""
INSERT INTO {table} ({format})
VALUES ({values})
RETURNING *;
""").format(
table=sql.Identifier(tableName),
format=sql.SQL(", ").join(
sql.Identifier(value) for value in tableFormat
),
values=sql.SQL(", ").join(
sql.Literal(value) for value in tableValues
)
)
debugPrint(sanitisedQuery.as_string(dbConnection))
_commitQuery(dbConnection, sanitisedQuery)
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT {field} FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
field=sql.Identifier(tableField),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT * FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
id=sql.Literal(RowID)
)
return _execQuery(dbConnection, sanitisedQuery)[0]
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
debugPrint("Attempting to get ID by username " + username + "...")
sanitisedQuery = sql.SQL("""
SELECT id FROM users WHERE "username" = {username}
""").format(
username=sql.Literal(username)
)
return _execQuery(dbConnection, sanitisedQuery)[0][0]
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
sanitisedQuery = sql.SQL("""
SELECT EXISTS(
SELECT
{fieldName}
FROM
{table}
WHERE
{fieldName} = {fieldValue}
);
""").format(
table=sql.Identifier(tableName),
fieldName=sql.Identifier(fieldName),
fieldValue=sql.Literal(fieldValue)
)
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
debugPrint("Commit query executing...")
@ -132,3 +69,62 @@ def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composab
dbResults = dbCursor.fetchall()
dbCursor.close()
return dbResults
# Callable helper functions
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
INSERT INTO {table} ({format})
VALUES ({values})
RETURNING *;
""").format(
table=sql.Identifier(tableName),
format=sql.SQL(", ").join(
sql.Identifier(value) for value in tableFormat
),
values=sql.SQL(", ").join(
sql.Literal(value) for value in tableValues
)
)
debugPrint(sanitisedQuery.as_string(dbConnection))
_commitQuery(dbConnection, sanitisedQuery)
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT {field} FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
field=sql.Identifier(tableField),
id=sql.Literal(RowID)
)
return str(_execQuery(dbConnection, sanitisedQuery)[0][0])
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
sanitisedQuery = sql.SQL("""
SELECT * FROM {table} WHERE "id" = {id}
""").format(
table=sql.Identifier(tableName),
id=sql.Literal(RowID)
)
return tuple(_execQuery(dbConnection, sanitisedQuery)[0])
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
sanitisedQuery = sql.SQL("""
SELECT EXISTS(
SELECT
{fieldName}
FROM
{table}
WHERE
{fieldName} = {fieldValue}
);
""").format(
table=sql.Identifier(tableName),
fieldName=sql.Identifier(fieldName),
fieldValue=sql.Literal(fieldValue)
)
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])

138
main.py
View File

@ -1,72 +1,78 @@
import sys
import atexit
import signal
from typing import Union
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import dbHandler
import userHandler
import securityHandler
dbConnection = dbHandler.connect("blorgdb",
"172.20.0.10",
"dev",
"dev",
"5432")
dbHandler.initTable(dbConnection, "Users", """
ID SERIAL PRIMARY KEY,
Username VARCHAR(255),
FirstName VARCHAR(255),
LastName VARCHAR(255),
Description VARCHAR(255),
Country VARCHAR(255),
Theme VARCHAR(255),
AccentColor VARCHAR(255),
PasswordHash VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "SignOns", """
ID SERIAL PRIMARY KEY,
UserID VARCHAR(255),
LoginSuccess BOOLEAN,
DateAttempted VARCHAR(255),
IPLocationAttempted VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "AuthTokens", """
ID SERIAL PRIMARY KEY,
Token VARCHAR(255),
OwnerID INTEGER,
DateCreated TIMESTAMP,
DateExpiry TIMESTAMP,
IPLocationCreated VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "Blogs", """
ID SERIAL PRIMARY KEY,
AuthorID INTEGER,
CategoryID INTEGER,
DatePosted TIMESTAMP,
Description VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "Categories", """
ID SERIAL PRIMARY KEY,
Name VARCHAR(255)
""")
dbConnection = None
#dbHandler.insertRow(dbConnection,
# 'users',
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
def apiInit():
dbConnection = dbHandler.connect("blorgdb",
"172.20.0.10",
"dev",
"dev",
"5432")
dbHandler.initTable(dbConnection, "Users", """
ID SERIAL PRIMARY KEY,
Username VARCHAR(255),
Email VARCHAR(255),
FirstName VARCHAR(255),
LastName VARCHAR(255),
Description VARCHAR(255),
Country VARCHAR(255),
Theme VARCHAR(255),
AccentColor VARCHAR(255),
PasswordHash VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "SignOns", """
ID SERIAL PRIMARY KEY,
UserID INTEGER,
LoginSuccess BOOLEAN,
DateAttempted TIMESTAMP,
IPLocationAttempted VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "AuthTokens", """
ID SERIAL PRIMARY KEY,
Token VARCHAR(255),
OwnerID INTEGER,
DateCreated TIMESTAMP,
DateExpiry TIMESTAMP,
IPLocationCreated VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "Blogs", """
ID SERIAL PRIMARY KEY,
AuthorID INTEGER,
CategoryID INTEGER,
DatePosted TIMESTAMP,
Description VARCHAR(255)
""")
dbHandler.initTable(dbConnection, "Categories", """
ID SERIAL PRIMARY KEY,
Name VARCHAR(255)
""")
userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser")
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
print(dbHandler.getRowByID(dbConnection, "users", 5))
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
print("It exists!")
else:
print("It does NOT exist!")
def apiCleanup():
dbConnection.close()
dbConnection.close()
@asynccontextmanager
async def apiLifespan(app: FastAPI):
# API Init
apiInit()
# API Clean up
yield
apiCleanup()
app = FastAPI()
app = FastAPI(lifespan=apiLifespan)
origins = [
"http://localhost",
@ -88,13 +94,31 @@ def getroot():
class ApiBody(BaseModel):
username: str
password: str
@app.post("/api")
def postapi(body: ApiBody):
print(body.username)
print(body.password)
return body
class loginBody(BaseModel):
username: str
password: str
rememberMe: bool
@app.post("/api/login")
def postlogin(body: loginBody, request: Request):
try:
if userHandler.checkUserExistence(dbConnection, loginBody.username):
userID = userHandler.getIDByUsername(dbConnection, loginBody.username)
if securityHandler.handlePassword(dbConnection, loginBody.password, userID):
return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"}
else:
return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."}
else:
return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."}
except:
return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."}
@app.get("/api")
def getapi():
return {"Hello": "API!"}

View File

@ -3,6 +3,7 @@ import argon2
import psycopg2
import dbHandler
import userHandler
debug: bool = True
@ -34,17 +35,13 @@ def verifyRehash(hash: str) -> bool:
return False
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash")
hash = userHandler.getHashValueByUserID(dbConnection, userID)
debugPrint("Now verifying password against hash for user ID " + userid + "...")
if verifyPassword(password, hash):
debugPrint("(USER ID) " + userID + " Password verification success!")
debugPrint("(USER ID " + userID + ") Password verification success!")
if verifyRehash(hash):
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...")
debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...")
return True
else:
debugPrint("(USER ID) " + userID + " Password verification failure!")
debugPrint("(USER ID " + userID + ") Password verification failure!")
return False
hashed: str = hashPassword("testing")
print(verifyPassword("testing", hashed))
print(verifyRehash(hashed))

14
tokenHandler.py Normal file
View File

@ -0,0 +1,14 @@
import psycopg2
from psycopg2 import sql
import dbHandler
debug: bool = True
def debugPrint(msg: str) -> None:
if debug:
print("(TOKEN HANDLER) PRINT: " + msg)
def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str:
debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...")
return "sha"

33
userHandler.py Normal file
View File

@ -0,0 +1,33 @@
import psycopg2
from psycopg2 import sql
import dbHandler
import securityHandler
debug: bool = True
def debugPrint(msg: str) -> None:
if debug:
print("(USER HANDLER) PRINT: " + msg)
def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str):
debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor)
dbHandler.insertRow(dbConnection,
'users',
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
[username, firstName, lastName, description, country, theme, accentColor, securityHandler.hashPassword(password)])
def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool:
return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username)
def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str:
return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash")
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
debugPrint("Attempting to get ID of username " + username + "...")
sanitisedQuery = sql.SQL("""
SELECT id FROM users WHERE "username" = {username}
""").format(
username=sql.Literal(username)
)
return int(dbHandler._execQuery(dbConnection, sanitisedQuery)[0][0])