# Blorg-Backend/dbHandler.py
# PostgreSQL access helpers built on psycopg2.

import sys
import typing
import psycopg2
from psycopg2 import sql
# Global switch for diagnostic output: debugPrint and debugPrintNotice print only while this is True.
debug: bool = True
def debugPrint(msg: str) -> None:
    """Print *msg* with the DB-handler debug prefix, only while ``debug`` is on.

    *msg* must be a ``str``; it is joined to the prefix with ``+``.
    """
    if not debug:
        return
    print("(DB HANDLER) PRINT: " + msg)
def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
    """Print entry *i* of ``dbConnection.notices`` when ``debug`` is enabled.

    psycopg2 collects server NOTICE/WARNING messages in the connection's
    ``notices`` list. A negative *i* counts from the end (``-1`` is the newest).

    If there is no notice at *i*, a short message is printed instead of
    raising ``IndexError``, so a missing notice cannot break the caller.
    """
    if not debug:
        return
    try:
        notice = dbConnection.notices[i]
    except IndexError:
        print("(DB HANDLER) No server notice at index " + str(i) + "!")
        return
    # Notice strings end with a newline; strip it so print() does not add a blank line.
    print("(DB HANDLER) " + notice.rstrip("\n"))
def errorPrint(msg: str) -> None:
    """Print *msg* with the DB-handler error prefix, regardless of ``debug``."""
    line = "(DB HANDLER) ERROR: " + msg
    print(line)
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection to a PostgreSQL database.

    Args:
        databaseOption: Database name.
        hostOption: Server host.
        userOption: User/role to connect as.
        passwordOption: Password for *userOption*.
        portOption: Server port.

    Returns:
        The open connection.

    If the connection cannot be established, the error is printed and the
    process exits with status 1 (``SystemExit`` is raised).
    """
    debugPrint("Attempting to connect to database...")
    try:
        dbConnection = psycopg2.connect(database=databaseOption,
                                        host=hostOption,
                                        user=userOption,
                                        password=passwordOption,
                                        port=portOption)
    except Exception as error:
        # ``except Exception`` (not a bare except) so Ctrl-C / SystemExit are
        # not swallowed; the cause is reported instead of being discarded.
        errorPrint("Error occurred connecting to database! Exiting...")
        errorPrint("Unexpected error: " + repr(error))
        sys.exit(1)
    debugPrint("Successfully connected to database!")
    return dbConnection
def disconnect(dbConnection: psycopg2.extensions.connection) -> None:
    """Close *dbConnection*.

    psycopg2 discards any uncommitted changes when a connection is closed.
    If closing fails, the error is printed and the process exits with
    status 1 (``SystemExit`` is raised).
    """
    debugPrint("Attempting to disconnect database...")
    try:
        dbConnection.close()
    except Exception as error:
        # Report through errorPrint (not debugPrint) so the failure is visible
        # even when debug output is off, since the process exits right after.
        errorPrint("Failed to disconnect database! Exiting...")
        errorPrint("Unexpected error: " + repr(error))
        sys.exit(1)
    debugPrint("Successfully disconnected database!")
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str):
    """Create table ``tableName.lower()`` from *tableFormat* unless it already exists.

    Existence is checked by the lower-cased name against
    ``INFORMATION_SCHEMA.TABLES``. The work is committed on success and rolled
    back if any statement fails.

    Args:
        dbConnection: Open connection to use.
        tableName: Table name. It is passed as a bound parameter and quoted
            with ``sql.Identifier``, so it cannot inject SQL.
        tableFormat: Column-definition list, placed verbatim between the
            parentheses of ``CREATE TABLE``. It is NOT escaped and must come
            from trusted code, never from user input.

    Raises:
        psycopg2.Error: If the existence check or the CREATE TABLE fails.
    """
    storedName = tableName.lower()
    # ``with connection`` commits on normal exit (including ``return``) and
    # rolls back on an exception; ``with cursor`` always closes the cursor.
    with dbConnection:
        with dbConnection.cursor() as dbCursor:
            dbCursor.execute(
                "SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = %s)",
                (storedName,),
            )
            if dbCursor.fetchone()[0]:
                debugPrint(tableName + " Table already exists! Skipping table creation...")
                return
            debugPrint(tableName + " Table does not exist! Creating table...")
            # IF NOT EXISTS guards against another process creating the table
            # between the check above and this statement.
            dbCursor.execute(
                sql.SQL("CREATE TABLE IF NOT EXISTS {table} ({format})").format(
                    table=sql.Identifier(storedName),
                    format=sql.SQL(tableFormat),
                )
            )
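# The private base functions below (_commitQuery, _execQuery) send the query exactly as given and do no sanitisation of their own,
# so a query built by concatenating untrusted strings is vulnerable to SQL injection. Do not call them directly; use the helper functions further down.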
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query*, commit it, and optionally return the result rows.

    No sanitisation is done here: *query* is sent as given, so the caller must
    already have composed it safely (e.g. with ``sql.Identifier`` /
    ``sql.Literal``).

    Args:
        dbConnection: Open connection to run the query on.
        query: The composed statement to execute.
        fetchResults: If true, return ``fetchall()`` of the statement.
            If false, return ``True`` without fetching.

    Returns:
        The list of rows, ``True``, or ``None`` if any step failed. Errors are
        printed with ``errorPrint`` rather than raised.
    """
    debugPrint("Commit query executing...")
    dbCursor = None
    try:
        dbCursor = dbConnection.cursor()
        dbCursor.execute(query)
        dbConnection.commit()
        # A default psycopg2 cursor holds its result set client-side, so rows
        # can still be fetched after the commit.
        if fetchResults:
            return dbCursor.fetchall()
        return True
    except Exception as error:
        errorPrint("Commit query failed! Unexpected error: " + repr(error))
        try:
            # A server-side error leaves the transaction aborted, and every later
            # statement on this connection fails until it is rolled back.
            if dbConnection.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
                dbConnection.rollback()
        except Exception as rollbackError:
            errorPrint("Rollback after failed query also failed! Unexpected error: " + repr(rollbackError))
        return None
    finally:
        # Close the cursor on every path, including failures and fetchResults=False.
        if dbCursor is not None:
            dbCursor.close()
def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query* without committing, and optionally return the result rows.

    No sanitisation is done here: *query* is sent as given, so the caller must
    already have composed it safely. Nothing is committed, so any changes stay
    in the connection's open transaction (unless the connection is in
    autocommit mode).

    Args:
        dbConnection: Open connection to run the query on.
        query: The composed statement to execute.
        fetchResults: If true, return ``fetchall()`` of the statement.
            If false, return ``True`` without fetching.

    Returns:
        The list of rows, ``True``, or ``None`` if any step failed. Errors are
        printed with ``errorPrint`` rather than raised.
    """
    debugPrint("Exec query executing...")
    dbCursor = None
    try:
        dbCursor = dbConnection.cursor()
        dbCursor.execute(query)
        if fetchResults:
            return dbCursor.fetchall()
        return True
    except Exception as error:
        errorPrint("Exec query failed! Unexpected error: " + repr(error))
        try:
            # A server-side error leaves the transaction aborted, and every later
            # statement on this connection fails until it is rolled back.
            # Only roll back in that state, so pending work is not discarded
            # after a purely client-side error.
            if dbConnection.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
                dbConnection.rollback()
        except Exception as rollbackError:
            errorPrint("Rollback after failed query also failed! Unexpected error: " + repr(rollbackError))
        return None
    finally:
        # Close the cursor on every path, including failures and fetchResults=False.
        if dbCursor is not None:
            dbCursor.close()
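# Callable helper functions: these compose their queries with psycopg2.sql (sql.Identifier for names, sql.Literal for values), so identifiers and values are escaped rather than spliced into the SQL text.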
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list) -> tuple:
    """Insert one row into *tableName*, commit it, and return the stored row.

    Identifiers are quoted with ``sql.Identifier`` and values are escaped with
    ``sql.Literal``.

    Args:
        dbConnection: Open connection to use.
        tableName: Target table, quoted exactly as given (quoted identifiers
            are case-sensitive in PostgreSQL).
        tableFormat: Column names; each is lower-cased before quoting.
        tableValues: Values for those columns, in the same order.

    Returns:
        The inserted row from ``RETURNING *``, including any columns the
        database filled in.

    Raises:
        RuntimeError: If the insert failed. The underlying error has already
            been printed by ``_commitQuery``. Raising here avoids the obscure
            ``TypeError`` that indexing its ``None`` result would produce.
    """
    debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
    sanitisedQuery = sql.SQL("\nINSERT INTO {table} ({format})\nVALUES ({values})\nRETURNING *;\n").format(
        table=sql.Identifier(tableName),
        format=sql.SQL(", ").join(sql.Identifier(column.lower()) for column in tableFormat),
        values=sql.SQL(", ").join(sql.Literal(value) for value in tableValues),
    )
    debugPrint(sanitisedQuery.as_string(dbConnection))
    insertedRows = _commitQuery(dbConnection, sanitisedQuery)
    if not insertedRows:
        raise RuntimeError("Failed to insert new row into table name " + tableName + "!")
    return insertedRows[0]
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
    """Return column *tableField* of the row whose ``id`` is *RowID*, as a ``str``.

    The value is passed through ``str()``, so a SQL NULL comes back as
    ``"None"``. A missing row raises ``IndexError``. A failed query
    (``_execQuery`` returns ``None``) raises ``TypeError``.
    """
    debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
    template = sql.SQL('\nSELECT {field} FROM {table} WHERE "id" = {id}\n')
    selectQuery = template.format(
        field=sql.Identifier(tableField),
        table=sql.Identifier(tableName),
        id=sql.Literal(RowID),
    )
    rows = _execQuery(dbConnection, selectQuery)
    firstRow = rows[0]
    return str(firstRow[0])
def changeFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str, newValue) -> typing.Optional[bool]:
    """Set column *tableField* to *newValue* on the row whose ``id`` is *RowID*.

    The UPDATE is committed immediately. The table and column names are quoted
    with ``sql.Identifier``, and *newValue* and *RowID* are escaped with
    ``sql.Literal``.

    Returns:
        ``True`` if the statement executed and committed, or ``None`` if it
        failed (the error is printed, not raised). The row count is not
        checked, so ``True`` is also returned when no row has that id.
    """
    debugPrint("Attempting to change value of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + " to " + str(newValue) + "...")
    template = sql.SQL('\nUPDATE {table} SET {field} = {value} WHERE "id" = {id}\n')
    updateQuery = template.format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        value=sql.Literal(newValue),
        id=sql.Literal(RowID),
    )
    return _commitQuery(dbConnection, updateQuery, False)
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
    """Return the row whose ``id`` is *RowID* from *tableName*, as a tuple.

    A missing row raises ``IndexError``. A failed query (``_execQuery``
    returns ``None``) raises ``TypeError``.
    """
    debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
    template = sql.SQL('\nSELECT * FROM {table} WHERE "id" = {id}\n')
    selectQuery = template.format(
        id=sql.Literal(RowID),
        table=sql.Identifier(tableName),
    )
    rows = _execQuery(dbConnection, selectQuery)
    return tuple(rows[0])
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
    """Return whether any row of *tableName* has *fieldName* equal to *fieldValue*.

    Uses ``SELECT EXISTS(...)``, which always yields exactly one row. A failed
    query (``_execQuery`` returns ``None``) raises ``TypeError``.
    """
    debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
    template = sql.SQL(
        "\n"
        "SELECT EXISTS(\n"
        "SELECT\n"
        "{fieldName}\n"
        "FROM\n"
        "{table}\n"
        "WHERE\n"
        "{fieldName} = {fieldValue}\n"
        ");\n"
    )
    existsQuery = template.format(
        fieldName=sql.Identifier(fieldName),
        fieldValue=sql.Literal(fieldValue),
        table=sql.Identifier(tableName),
    )
    rows = _execQuery(dbConnection, existsQuery)
    existsFlag = rows[0][0]
    return bool(existsFlag)
def getRowRangeByID(dbConnection: psycopg2.extensions.connection, tableName: str, rangeStart: int, rangeEnd: int, latestRecords: bool = True) -> tuple:
    """Return the rows of *tableName* whose ``id`` is in [*rangeStart*, *rangeEnd*].

    Both bounds are inclusive. Rows are ordered by ``id``: descending when
    *latestRecords* is truthy, ascending otherwise. A failed query
    (``_execQuery`` returns ``None``) raises ``TypeError``.
    """
    debugPrint("Getting rows from table name " + tableName + " from range " + str(rangeStart) + "-" + str(rangeEnd) + "...")
    template = sql.SQL("\nSELECT * FROM {table} WHERE id >= {start} AND id <= {end} ORDER BY id {order}\n")
    tableIdent = sql.Identifier(tableName)
    lowerBound = sql.Literal(rangeStart)
    upperBound = sql.Literal(rangeEnd)
    if latestRecords:
        direction = "DESC"
    else:
        direction = "ASC"
    rangeQuery = template.format(
        table=tableIdent,
        start=lowerBound,
        end=upperBound,
        order=sql.SQL(direction),
    )
    return tuple(_execQuery(dbConnection, rangeQuery))