Move user specific functions to new userHandler module, setup most of logic for password authentication, init token handling module. Properly setup API init and cleanup functions
This commit is contained in:
parent
8c91612a5c
commit
dabf0c8977
130
dbHandler.py
130
dbHandler.py
|
|
@ -12,6 +12,9 @@ def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> No
|
||||||
if debug:
|
if debug:
|
||||||
print("(DB HANDLER) " + dbConnection.notices[i])
|
print("(DB HANDLER) " + dbConnection.notices[i])
|
||||||
|
|
||||||
|
def errorPrint(msg: str) -> None:
|
||||||
|
print("(DB HANDLER) ERROR: " + msg)
|
||||||
|
|
||||||
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
|
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
|
||||||
debugPrint("Attempting to connect to database...")
|
debugPrint("Attempting to connect to database...")
|
||||||
try:
|
try:
|
||||||
|
|
@ -23,7 +26,7 @@ def connect(databaseOption: str, hostOption: str, userOption: str, passwordOptio
|
||||||
debugPrint("Successfully connected to database!")
|
debugPrint("Successfully connected to database!")
|
||||||
return dbConnection
|
return dbConnection
|
||||||
except:
|
except:
|
||||||
debugPrint("Error occurred connecting to database! Exiting...")
|
errorPrint("Error occurred connecting to database! Exiting...")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
|
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
|
||||||
|
|
@ -50,72 +53,6 @@ def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tabl
|
||||||
dbConnection.commit()
|
dbConnection.commit()
|
||||||
dbCursor.close()
|
dbCursor.close()
|
||||||
|
|
||||||
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
|
|
||||||
debugPrint("Attempting to insert new row...")
|
|
||||||
sanitisedQuery = sql.SQL("""
|
|
||||||
INSERT INTO {table} ({format})
|
|
||||||
VALUES ({values})
|
|
||||||
RETURNING *;
|
|
||||||
""").format(
|
|
||||||
table=sql.Identifier(tableName),
|
|
||||||
format=sql.SQL(", ").join(
|
|
||||||
sql.Identifier(value) for value in tableFormat
|
|
||||||
),
|
|
||||||
values=sql.SQL(", ").join(
|
|
||||||
sql.Literal(value) for value in tableValues
|
|
||||||
)
|
|
||||||
)
|
|
||||||
debugPrint(sanitisedQuery.as_string(dbConnection))
|
|
||||||
_commitQuery(dbConnection, sanitisedQuery)
|
|
||||||
|
|
||||||
def getFieldByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
|
|
||||||
debugPrint("Attempting to get field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
|
|
||||||
sanitisedQuery = sql.SQL("""
|
|
||||||
SELECT {field} FROM {table} WHERE "id" = {id}
|
|
||||||
""").format(
|
|
||||||
table=sql.Identifier(tableName),
|
|
||||||
field=sql.Identifier(tableField),
|
|
||||||
id=sql.Literal(RowID)
|
|
||||||
)
|
|
||||||
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
|
||||||
|
|
||||||
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
|
|
||||||
debugPrint("Attempting to get row by ID " + str(RowID) + " in table name " + tableName + "...")
|
|
||||||
sanitisedQuery = sql.SQL("""
|
|
||||||
SELECT * FROM {table} WHERE "id" = {id}
|
|
||||||
""").format(
|
|
||||||
table=sql.Identifier(tableName),
|
|
||||||
id=sql.Literal(RowID)
|
|
||||||
)
|
|
||||||
return _execQuery(dbConnection, sanitisedQuery)[0]
|
|
||||||
|
|
||||||
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
|
|
||||||
debugPrint("Attempting to get ID by username " + username + "...")
|
|
||||||
sanitisedQuery = sql.SQL("""
|
|
||||||
SELECT id FROM users WHERE "username" = {username}
|
|
||||||
""").format(
|
|
||||||
username=sql.Literal(username)
|
|
||||||
)
|
|
||||||
return _execQuery(dbConnection, sanitisedQuery)[0][0]
|
|
||||||
|
|
||||||
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
|
|
||||||
debugPrint("Checking if field name " + fieldName + " in " + tableName + " with value " + str(fieldValue) + " exists...")
|
|
||||||
sanitisedQuery = sql.SQL("""
|
|
||||||
SELECT EXISTS(
|
|
||||||
SELECT
|
|
||||||
{fieldName}
|
|
||||||
FROM
|
|
||||||
{table}
|
|
||||||
WHERE
|
|
||||||
{fieldName} = {fieldValue}
|
|
||||||
);
|
|
||||||
""").format(
|
|
||||||
table=sql.Identifier(tableName),
|
|
||||||
fieldName=sql.Identifier(fieldName),
|
|
||||||
fieldValue=sql.Literal(fieldValue)
|
|
||||||
)
|
|
||||||
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
|
||||||
|
|
||||||
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
|
# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
|
||||||
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
|
||||||
debugPrint("Commit query executing...")
|
debugPrint("Commit query executing...")
|
||||||
|
|
@ -132,3 +69,62 @@ def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composab
|
||||||
dbResults = dbCursor.fetchall()
|
dbResults = dbCursor.fetchall()
|
||||||
dbCursor.close()
|
dbCursor.close()
|
||||||
return dbResults
|
return dbResults
|
||||||
|
|
||||||
|
# Callable helper functions
|
||||||
|
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list):
|
||||||
|
debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
INSERT INTO {table} ({format})
|
||||||
|
VALUES ({values})
|
||||||
|
RETURNING *;
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
format=sql.SQL(", ").join(
|
||||||
|
sql.Identifier(value) for value in tableFormat
|
||||||
|
),
|
||||||
|
values=sql.SQL(", ").join(
|
||||||
|
sql.Literal(value) for value in tableValues
|
||||||
|
)
|
||||||
|
)
|
||||||
|
debugPrint(sanitisedQuery.as_string(dbConnection))
|
||||||
|
_commitQuery(dbConnection, sanitisedQuery)
|
||||||
|
|
||||||
|
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
|
||||||
|
debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT {field} FROM {table} WHERE "id" = {id}
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
field=sql.Identifier(tableField),
|
||||||
|
id=sql.Literal(RowID)
|
||||||
|
)
|
||||||
|
return str(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||||
|
|
||||||
|
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
|
||||||
|
debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT * FROM {table} WHERE "id" = {id}
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
id=sql.Literal(RowID)
|
||||||
|
)
|
||||||
|
return tuple(_execQuery(dbConnection, sanitisedQuery)[0])
|
||||||
|
|
||||||
|
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
|
||||||
|
debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT
|
||||||
|
{fieldName}
|
||||||
|
FROM
|
||||||
|
{table}
|
||||||
|
WHERE
|
||||||
|
{fieldName} = {fieldValue}
|
||||||
|
);
|
||||||
|
""").format(
|
||||||
|
table=sql.Identifier(tableName),
|
||||||
|
fieldName=sql.Identifier(fieldName),
|
||||||
|
fieldValue=sql.Literal(fieldValue)
|
||||||
|
)
|
||||||
|
return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||||
|
|
||||||
|
|
|
||||||
130
main.py
130
main.py
|
|
@ -1,72 +1,78 @@
|
||||||
import sys
|
import sys
|
||||||
|
import atexit
|
||||||
|
import signal
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import dbHandler
|
import dbHandler
|
||||||
|
import userHandler
|
||||||
|
import securityHandler
|
||||||
|
|
||||||
dbConnection = dbHandler.connect("blorgdb",
|
dbConnection = None
|
||||||
|
|
||||||
|
def apiInit():
|
||||||
|
dbConnection = dbHandler.connect("blorgdb",
|
||||||
"172.20.0.10",
|
"172.20.0.10",
|
||||||
"dev",
|
"dev",
|
||||||
"dev",
|
"dev",
|
||||||
"5432")
|
"5432")
|
||||||
dbHandler.initTable(dbConnection, "Users", """
|
dbHandler.initTable(dbConnection, "Users", """
|
||||||
ID SERIAL PRIMARY KEY,
|
ID SERIAL PRIMARY KEY,
|
||||||
Username VARCHAR(255),
|
Username VARCHAR(255),
|
||||||
FirstName VARCHAR(255),
|
Email VARCHAR(255),
|
||||||
LastName VARCHAR(255),
|
FirstName VARCHAR(255),
|
||||||
Description VARCHAR(255),
|
LastName VARCHAR(255),
|
||||||
Country VARCHAR(255),
|
Description VARCHAR(255),
|
||||||
Theme VARCHAR(255),
|
Country VARCHAR(255),
|
||||||
AccentColor VARCHAR(255),
|
Theme VARCHAR(255),
|
||||||
PasswordHash VARCHAR(255)
|
AccentColor VARCHAR(255),
|
||||||
""")
|
PasswordHash VARCHAR(255)
|
||||||
dbHandler.initTable(dbConnection, "SignOns", """
|
""")
|
||||||
ID SERIAL PRIMARY KEY,
|
dbHandler.initTable(dbConnection, "SignOns", """
|
||||||
UserID VARCHAR(255),
|
ID SERIAL PRIMARY KEY,
|
||||||
LoginSuccess BOOLEAN,
|
UserID INTEGER,
|
||||||
DateAttempted VARCHAR(255),
|
LoginSuccess BOOLEAN,
|
||||||
IPLocationAttempted VARCHAR(255)
|
DateAttempted TIMESTAMP,
|
||||||
""")
|
IPLocationAttempted VARCHAR(255)
|
||||||
dbHandler.initTable(dbConnection, "AuthTokens", """
|
""")
|
||||||
ID SERIAL PRIMARY KEY,
|
dbHandler.initTable(dbConnection, "AuthTokens", """
|
||||||
Token VARCHAR(255),
|
ID SERIAL PRIMARY KEY,
|
||||||
OwnerID INTEGER,
|
Token VARCHAR(255),
|
||||||
DateCreated TIMESTAMP,
|
OwnerID INTEGER,
|
||||||
DateExpiry TIMESTAMP,
|
DateCreated TIMESTAMP,
|
||||||
IPLocationCreated VARCHAR(255)
|
DateExpiry TIMESTAMP,
|
||||||
""")
|
IPLocationCreated VARCHAR(255)
|
||||||
dbHandler.initTable(dbConnection, "Blogs", """
|
""")
|
||||||
ID SERIAL PRIMARY KEY,
|
dbHandler.initTable(dbConnection, "Blogs", """
|
||||||
AuthorID INTEGER,
|
ID SERIAL PRIMARY KEY,
|
||||||
CategoryID INTEGER,
|
AuthorID INTEGER,
|
||||||
DatePosted TIMESTAMP,
|
CategoryID INTEGER,
|
||||||
Description VARCHAR(255)
|
DatePosted TIMESTAMP,
|
||||||
""")
|
Description VARCHAR(255)
|
||||||
dbHandler.initTable(dbConnection, "Categories", """
|
""")
|
||||||
ID SERIAL PRIMARY KEY,
|
dbHandler.initTable(dbConnection, "Categories", """
|
||||||
Name VARCHAR(255)
|
ID SERIAL PRIMARY KEY,
|
||||||
""")
|
Name VARCHAR(255)
|
||||||
|
""")
|
||||||
|
userHandler.createUser(dbConnection, "testuser", "Test", "User", "A test user", "TestCountry", "TestTheme", "TestColor", "testuser")
|
||||||
|
|
||||||
#dbHandler.insertRow(dbConnection,
|
def apiCleanup():
|
||||||
# 'users',
|
dbConnection.close()
|
||||||
# ['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
|
||||||
# ['cspark', 'Curt', 'Spark', 'A short description', 'United Kingdom', 'light', 'purple', 'hash256'])
|
|
||||||
|
|
||||||
print(dbHandler._execQuery(dbConnection, "SELECT * FROM users"))
|
@asynccontextmanager
|
||||||
print(dbHandler.getRowByID(dbConnection, "users", 5))
|
async def apiLifespan(app: FastAPI):
|
||||||
print(dbHandler.getFieldByID(dbConnection, "users", 5, "description"))
|
# API Init
|
||||||
if dbHandler.checkFieldValueExistence(dbConnection, "users", "username", "cspark"):
|
apiInit()
|
||||||
print("It exists!")
|
|
||||||
else:
|
|
||||||
print("It does NOT exist!")
|
|
||||||
|
|
||||||
dbConnection.close()
|
# API Clean up
|
||||||
|
yield
|
||||||
|
apiCleanup()
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI(lifespan=apiLifespan)
|
||||||
|
|
||||||
origins = [
|
origins = [
|
||||||
"http://localhost",
|
"http://localhost",
|
||||||
|
|
@ -88,13 +94,31 @@ def getroot():
|
||||||
class ApiBody(BaseModel):
|
class ApiBody(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
@app.post("/api")
|
@app.post("/api")
|
||||||
def postapi(body: ApiBody):
|
def postapi(body: ApiBody):
|
||||||
print(body.username)
|
print(body.username)
|
||||||
print(body.password)
|
print(body.password)
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
class loginBody(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
rememberMe: bool
|
||||||
|
@app.post("/api/login")
|
||||||
|
def postlogin(body: loginBody, request: Request):
|
||||||
|
try:
|
||||||
|
if userHandler.checkUserExistence(dbConnection, loginBody.username):
|
||||||
|
userID = userHandler.getIDByUsername(dbConnection, loginBody.username)
|
||||||
|
if securityHandler.handlePassword(dbConnection, loginBody.password, userID):
|
||||||
|
return {"success": True, "authToken": tokenHandler.createToken(dbConnection, userID, loginBody.rememberMe, request.client.host), "message": "User login success!"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "authToken": "none", "message": "User login failed! Please check your password."}
|
||||||
|
else:
|
||||||
|
return {"success": False, "authToken": "none", "message": "User login failed! User does not exist."}
|
||||||
|
except:
|
||||||
|
return {"success": False, "authToken": "none", "message": "User login failed! Unexpected server error."}
|
||||||
|
|
||||||
@app.get("/api")
|
@app.get("/api")
|
||||||
def getapi():
|
def getapi():
|
||||||
return {"Hello": "API!"}
|
return {"Hello": "API!"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import argon2
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
|
||||||
import dbHandler
|
import dbHandler
|
||||||
|
import userHandler
|
||||||
|
|
||||||
debug: bool = True
|
debug: bool = True
|
||||||
|
|
||||||
|
|
@ -34,17 +35,13 @@ def verifyRehash(hash: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
|
def handlePassword(dbConnection: psycopg2.extensions.connection, password: str, userID: int) -> bool:
|
||||||
hash = dbHandler.getFieldByID(dbConnection, "users", userID, "passwordhash")
|
hash = userHandler.getHashValueByUserID(dbConnection, userID)
|
||||||
debugPrint("Now verifying password against hash for user ID " + userid + "...")
|
debugPrint("Now verifying password against hash for user ID " + userid + "...")
|
||||||
if verifyPassword(password, hash):
|
if verifyPassword(password, hash):
|
||||||
debugPrint("(USER ID) " + userID + " Password verification success!")
|
debugPrint("(USER ID " + userID + ") Password verification success!")
|
||||||
if verifyRehash(hash):
|
if verifyRehash(hash):
|
||||||
debugPrint("(USER ID) " + userID + " Hash needs to be rehashed! Will now rehash...")
|
debugPrint("(USER ID " + userID + ") Hash needs to be rehashed! Will now rehash...")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
debugPrint("(USER ID) " + userID + " Password verification failure!")
|
debugPrint("(USER ID " + userID + ") Password verification failure!")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
hashed: str = hashPassword("testing")
|
|
||||||
print(verifyPassword("testing", hashed))
|
|
||||||
print(verifyRehash(hashed))
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
|
||||||
|
import dbHandler
|
||||||
|
|
||||||
|
debug: bool = True
|
||||||
|
|
||||||
|
def debugPrint(msg: str) -> None:
|
||||||
|
if debug:
|
||||||
|
print("(TOKEN HANDLER) PRINT: " + msg)
|
||||||
|
|
||||||
|
def createToken(dbConnection: psycopg2.extensions.connection, userID: int, rememberMe: bool, locationIP: str) -> str:
|
||||||
|
debugPrint("Now creating new token with following attributes : userID = " + str(userID) + ", rememberMe = " + str(rememberMe) + ", locationIP = " + locationIP + "...")
|
||||||
|
return "sha"
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
|
||||||
|
import dbHandler
|
||||||
|
import securityHandler
|
||||||
|
|
||||||
|
debug: bool = True
|
||||||
|
|
||||||
|
def debugPrint(msg: str) -> None:
|
||||||
|
if debug:
|
||||||
|
print("(USER HANDLER) PRINT: " + msg)
|
||||||
|
|
||||||
|
def createUser(dbConnection: psycopg2.extensions.connection, username: str, firstName: str, lastName: str, description: str, country: str, theme: str, accentColor: str, password: str):
|
||||||
|
debugPrint("Now creating new user with following attributes : username = " + username + ", firstName = " + firstName + ", lastName = " + lastName + ", description = " + description + ", country = " + country + ", theme = " + theme + ", accentColor = " + accentColor)
|
||||||
|
dbHandler.insertRow(dbConnection,
|
||||||
|
'users',
|
||||||
|
['username', 'firstname', 'lastname', 'description', 'country', 'theme', 'accentcolor', 'passwordhash'],
|
||||||
|
[username, firstName, lastName, description, country, theme, accentColor, securityHandler.hashPassword(password)])
|
||||||
|
|
||||||
|
def checkUserExistence(dbConnection: psycopg2.extensions.connection, username: str) -> bool:
|
||||||
|
return dbHandler.checkFieldValueExistence(dbConnection, "users", "username", username)
|
||||||
|
|
||||||
|
def getHashValuebyUserID(dbConnection: psycopg2.extensions.connection, userID: int) -> str:
|
||||||
|
return dbHandler.getFieldValueByID(dbConnection, "users", userID, "passwordhash")
|
||||||
|
|
||||||
|
def getIDByUsername(dbConnection: psycopg2.extensions.connection, username: str) -> int:
|
||||||
|
debugPrint("Attempting to get ID of username " + username + "...")
|
||||||
|
sanitisedQuery = sql.SQL("""
|
||||||
|
SELECT id FROM users WHERE "username" = {username}
|
||||||
|
""").format(
|
||||||
|
username=sql.Literal(username)
|
||||||
|
)
|
||||||
|
return int(dbHandler._execQuery(dbConnection, sanitisedQuery)[0][0])
|
||||||
Loading…
Reference in New Issue