Move user specific functions to new userHandler module, setup most of logic for password authentication, init token handling module. Properly setup API init and cleanup functions
This commit is contained in:
parent
8c91612a5c
commit
dabf0c8977
138
dbHandler.py
138
dbHandler.py
|
|
@ -12,18 +12,21 @@ def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> No
|
|||
if debug:
|
||||
print("(DB HANDLER) " + dbConnection.notices[i])
|
||||
|
||||
def errorPrint(msg: str) -> None:
|
||||
print("(DB HANDLER) ERROR: " + msg)
|
||||
|
||||
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
|
||||
debugPrint("Attempting to connect to database...")
|
||||
try:
|
||||
dbConnection = psycopg2.connect(database=databaseOption,
|
||||
host=hostOption,
|
||||
user=userOption,
|
||||
password=passwordOption,
|
||||
port=portOption)
|
||||
host=hostOption,
|
||||
user=userOption,
|
||||
password=passwordOption,
|
||||
port=portOption)
|
||||
debugPrint("Successfully connected to database!")
|
||||
return dbConnection
|
||||
except:
|
||||
debugPrint("Error occurred connecting to database! Exiting...")
|
||||
errorPrint("Error occurred connecting to database! Exiting...")
|
||||
sys.exit(1)
|
||||
|
||||
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
|
||||
|
|
@ -50,72 +53,6 @@ def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
|
|||
dbConnection.commit()
|
||||
dbCursor.close()
|
||||
|
||||
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
|
||||
debugPrint("Attempting to insert new row...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
INSERT INTO {table} ({format})
|
||||
VALUES ({values})
|
||||
RETURNING *;
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
format=sql.SQL(", ").join(
|
||||
sql.Identifier(value) for value in tableFormat
|
||||
),
|
||||
values=sql.SQL(", ").join(
|
||||
sql.Literal(value) for value in tableValues
|
||||
)
|
||||
)
|
||||
debugPrint(sanitisedQuery.as_string(dbConnection))
|
||||
_commitQuery(dbConnection, sanitisedQuery)
|
||||
|
||||
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
|
||||
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT {field} FROM {table} WHERE "id" = {id}
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
field=sql.Identifier(tableField),
|
||||
id=sql.Literal(RowID)
|
||||
)
|
||||
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
||||
|
||||
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
|
||||
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT * FROM {table} WHERE "id" = {id}
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
id=sql.Literal(RowID)
|
||||
)
|
||||
return _execQuery(dbConnection, sanitisedQuery)[0]
|
||||
|
||||
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
|
||||
debugPrint("Attempting to get ID by username " + username + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT id FROM users WHERE "username" = {username}
|
||||
""").format(
|
||||
username=sql.Literal(username)
|
||||
)
|
||||
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
||||
|
||||
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
|
||||
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT EXISTS(
|
||||
SELECT
|
||||
{fieldName}
|
||||
FROM
|
||||
{table}
|
||||
WHERE
|
||||
{fieldName} = {fieldValue}
|
||||
);
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
fieldName=sql.Identifier(fieldName),
|
||||
fieldValue=sql.Literal(fieldValue)
|
||||
)
|
||||
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||
|
||||
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
|
||||
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
||||
debugPrint("Commit query executing...")
|
||||
|
|
@ -132,3 +69,62 @@ def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composab
|
|||
dbResults = dbCursor.fetchall()
|
||||
dbCursor.close()
|
||||
return dbResults
|
||||
|
||||
# Callable helper functions
|
||||
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
|
||||
debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
INSERT INTO {table} ({format})
|
||||
VALUES ({values})
|
||||
RETURNING *;
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
format=sql.SQL(", ").join(
|
||||
sql.Identifier(value) for value in tableFormat
|
||||
),
|
||||
values=sql.SQL(", ").join(
|
||||
sql.Literal(value) for value in tableValues
|
||||
)
|
||||
)
|
||||
debugPrint(sanitisedQuery.as_string(dbConnection))
|
||||
_commitQuery(dbConnection, sanitisedQuery)
|
||||
|
||||
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
|
||||
debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT {field} FROM {table} WHERE "id" = {id}
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
field=sql.Identifier(tableField),
|
||||
id=sql.Literal(RowID)
|
||||
)
|
||||
return str(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||
|
||||
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
|
||||
debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT * FROM {table} WHERE "id" = {id}
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
id=sql.Literal(RowID)
|
||||
)
|
||||
return tuple(_execQuery(dbConnection, sanitisedQuery)[0])
|
||||
|
||||
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
|
||||
debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT EXISTS(
|
||||
SELECT
|
||||
{fieldName}
|
||||
FROM
|
||||
{table}
|
||||
WHERE
|
||||
{fieldName} = {fieldValue}
|
||||
);
|
||||
""").format(
|
||||
table=sql.Identifier(tableName),
|
||||
fieldName=sql.Identifier(fieldName),
|
||||
fieldValue=sql.Literal(fieldValue)
|
||||
)
|
||||
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||
|
||||
|
|
|
|||
138
main.py
138
main.py
|
|
@ -1,72 +1,78 @@
|
|||
import sys
|
||||
import atexit
|
||||
import signal
|
||||
from typing import Union
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import dbHandler
|
||||
import userHandler
|
||||
import securityHandler
|
||||
|
||||
dbConnection = dbHandler.connect("blorgdb",
|
||||
"172.20.0.10",
|
||||
"dev",
|
||||
"dev",
|
||||
"5432")
|
||||
dbHandler.initTable(dbConnection, "Users", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Username VARCHAR(255),
|
||||
FirstName VARCHAR(255),
|
||||
LastName VARCHAR(255),
|
||||
Description VARCHAR(255),
|
||||
Country VARCHAR(255),
|
||||
Theme VARCHAR(255),
|
||||
AccentColor VARCHAR(255),
|
||||
PasswordHash VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "SignOns", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
UserID VARCHAR(255),
|
||||
LoginSuccess BOOLEAN,
|
||||
DateAttempted VARCHAR(255),
|
||||
IPLocationAttempted VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "AuthTokens", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Token VARCHAR(255),
|
||||
OwnerID INTEGER,
|
||||
DateCreated TIMESTAMP,
|
||||
DateExpiry TIMESTAMP,
|
||||
IPLocationCreated VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "Blogs", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
AuthorID INTEGER,
|
||||
CategoryID INTEGER,
|
||||
DatePosted TIMESTAMP,
|
||||
Description VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "Categories", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Name VARCHAR(255)
|
||||
""")
|
||||
dbConnection = None
|
||||
|
||||
#dbHandler.insertRow(dbConnection,
|
||||
# 'users',
|
||||
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
||||
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
|
||||
def apiInit():
|
||||
dbConnection = dbHandler.connect("blorgdb",
|
||||
"172.20.0.10",
|
||||
"dev",
|
||||
"dev",
|
||||
"5432")
|
||||
dbHandler.initTable(dbConnection, "Users", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Username VARCHAR(255),
|
||||
Email VARCHAR(255),
|
||||
FirstName VARCHAR(255),
|
||||
LastName VARCHAR(255),
|
||||
Description VARCHAR(255),
|
||||
Country VARCHAR(255),
|
||||
Theme VARCHAR(255),
|
||||
AccentColor VARCHAR(255),
|
||||
PasswordHash VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "SignOns", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
UserID INTEGER,
|
||||
LoginSuccess BOOLEAN,
|
||||
DateAttempted TIMESTAMP,
|
||||
IPLocationAttempted VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "AuthTokens", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Token VARCHAR(255),
|
||||
OwnerID INTEGER,
|
||||
DateCreated TIMESTAMP,
|
||||
DateExpiry TIMESTAMP,
|
||||
IPLocationCreated VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "Blogs", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
AuthorID INTEGER,
|
||||
CategoryID INTEGER,
|
||||
DatePosted TIMESTAMP,
|
||||
Description VARCHAR(255)
|
||||
""")
|
||||
dbHandler.initTable(dbConnection, "Categories", """
|
||||
ID SERIAL PRIMARY KEY,
|
||||
Name VARCHAR(255)
|
||||
""")
|
||||
userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser")
|
||||
|
||||
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
|
||||
print(dbHandler.getRowByID(dbConnection, "users", 5))
|
||||
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
|
||||
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
|
||||
print("It exists!")
|
||||
else:
|
||||
print("It does NOT exist!")
|
||||
def apiCleanup():
|
||||
dbConnection.close()
|
||||
|
||||
dbConnection.close()
|
||||
@asynccontextmanager
|
||||
async def apiLifespan(app: FastAPI):
|
||||
# API Init
|
||||
apiInit()
|
||||
|
||||
# API Clean up
|
||||
yield
|
||||
apiCleanup()
|
||||
|
||||
app = FastAPI()
|
||||
app = FastAPI(lifespan=apiLifespan)
|
||||
|
||||
origins = [
|
||||
"http://localhost",
|
||||
|
|
@ -88,13 +94,31 @@ def getroot():
|
|||
class ApiBody(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@app.post("/api")
|
||||
def postapi(body: ApiBody):
|
||||
print(body.username)
|
||||
print(body.password)
|
||||
return body
|
||||
|
||||
class loginBody(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
rememberMe: bool
|
||||
@app.post("/api/login")
|
||||
def postlogin(body: loginBody, request: Request):
|
||||
try:
|
||||
if userHandler.checkUserExistence(dbConnection, loginBody.username):
|
||||
userID = userHandler.getIDByUsername(dbConnection, loginBody.username)
|
||||
if securityHandler.handlePassword(dbConnection, loginBody.password, userID):
|
||||
return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"}
|
||||
else:
|
||||
return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."}
|
||||
else:
|
||||
return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."}
|
||||
except:
|
||||
return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."}
|
||||
|
||||
@app.get("/api")
|
||||
def getapi():
|
||||
return {"Hello": "API!"}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import argon2
|
|||
import psycopg2
|
||||
|
||||
import dbHandler
|
||||
import userHandler
|
||||
|
||||
debug: bool = True
|
||||
|
||||
|
|
@ -34,17 +35,13 @@ def verifyRehash(hash: str) -> bool:
|
|||
return False
|
||||
|
||||
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
|
||||
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash")
|
||||
hash = userHandler.getHashValueByUserID(dbConnection, userID)
|
||||
debugPrint("Now verifying password against hash for user ID " + userid + "...")
|
||||
if verifyPassword(password, hash):
|
||||
debugPrint("(USER ID) " + userID + " Password verification success!")
|
||||
debugPrint("(USER ID " + userID + ") Password verification success!")
|
||||
if verifyRehash(hash):
|
||||
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...")
|
||||
debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...")
|
||||
return True
|
||||
else:
|
||||
debugPrint("(USER ID) " + userID + " Password verification failure!")
|
||||
debugPrint("(USER ID " + userID + ") Password verification failure!")
|
||||
return False
|
||||
|
||||
hashed: str = hashPassword("testing")
|
||||
print(verifyPassword("testing", hashed))
|
||||
print(verifyRehash(hashed))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
import dbHandler
|
||||
|
||||
debug: bool = True
|
||||
|
||||
def debugPrint(msg: str) -> None:
|
||||
if debug:
|
||||
print("(TOKEN HANDLER) PRINT: " + msg)
|
||||
|
||||
def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str:
|
||||
debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...")
|
||||
return "sha"
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
import dbHandler
|
||||
import securityHandler
|
||||
|
||||
debug: bool = True
|
||||
|
||||
def debugPrint(msg: str) -> None:
|
||||
if debug:
|
||||
print("(USER HANDLER) PRINT: " + msg)
|
||||
|
||||
def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str):
|
||||
debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor)
|
||||
dbHandler.insertRow(dbConnection,
|
||||
'users',
|
||||
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
||||
[username, firstName, lastName, description, country, theme, accentColor, securityHandler.hashPassword(password)])
|
||||
|
||||
def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool:
|
||||
return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username)
|
||||
|
||||
def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str:
|
||||
return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash")
|
||||
|
||||
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
|
||||
debugPrint("Attempting to get ID of username " + username + "...")
|
||||
sanitisedQuery = sql.SQL("""
|
||||
SELECT id FROM users WHERE "username" = {username}
|
||||
""").format(
|
||||
username=sql.Literal(username)
|
||||
)
|
||||
return int(dbHandler._execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||
Loading…
Reference in New Issue