"""Thin PostgreSQL helper layer built on psycopg2.

Provides connection management, idempotent table creation and a small set of
row-level helpers whose queries are composed with ``psycopg2.sql`` so that
identifiers and values are quoted by the driver.
"""

import sys
import typing

import psycopg2
from psycopg2 import sql

# When True, progress messages and server notices are echoed to stdout.
debug: bool = True

# Dollar-quote tag delimiting the anonymous PL/pgSQL block in initTable.
# A dedicated tag (rather than a bare "$$") lets initTable reject any input
# that could terminate the block early.
_INIT_TABLE_DOLLAR_TAG = "$init_table$"


def debugPrint(msg: str) -> None:
    """Print *msg* with the handler's debug prefix when ``debug`` is enabled."""
    if debug:
        print("(DB HANDLER) PRINT: " + msg)


def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
    """Print the server notice at index *i* of ``dbConnection.notices``.

    Nothing is printed when ``debug`` is disabled or when no notice exists at
    that index (for example when the server emitted no notice at all), so a
    missing notice can never interrupt the caller.
    """
    if not debug:
        return
    try:
        notice = dbConnection.notices[i]
    except IndexError:
        return
    print("(DB HANDLER) " + notice)


def errorPrint(msg: str) -> None:
    """Print *msg* with the handler's error prefix (regardless of ``debug``)."""
    print("(DB HANDLER) ERROR: " + msg)


def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str,
            portOption: str) -> psycopg2.extensions.connection:
    """Open a connection to the database, exiting the process on failure.

    Args:
        databaseOption: Database name.
        hostOption: Server host.
        userOption: Login role.
        passwordOption: Password for *userOption*.
        portOption: Server port.

    Returns:
        The open psycopg2 connection.

    Raises:
        SystemExit: With exit status 1 when the connection cannot be opened.
    """
    debugPrint("Attempting to connect to database...")
    try:
        dbConnection = psycopg2.connect(
            database=databaseOption,
            host=hostOption,
            user=userOption,
            password=passwordOption,
            port=portOption,
        )
    except Exception as error:
        # `Exception` rather than a bare `except` so Ctrl-C is not swallowed.
        errorPrint("Error occurred connecting to database! Exiting...")
        errorPrint("Connection failure details: " + repr(error))
        sys.exit(1)
    debugPrint("Successfully connected to database!")
    return dbConnection


def disconnect(dbConnection: psycopg2.extensions.connection) -> None:
    """Close *dbConnection*, exiting the process if closing fails.

    Raises:
        SystemExit: With exit status 1 when the connection cannot be closed.
    """
    debugPrint("Attempting to disconnect database...")
    try:
        dbConnection.close()
    except Exception as error:
        # Reported via errorPrint so the reason for exiting is visible even
        # when debug output is switched off.
        errorPrint("Failed to disconnect database! Exiting...")
        errorPrint("Disconnect failure details: " + repr(error))
        sys.exit(1)
    debugPrint("Successfully disconnected database!")


def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str) -> None:
    """Create table *tableName* unless a table with that name already exists.

    The existence check and the creation run inside one anonymous PL/pgSQL
    block which raises a NOTICE describing the outcome; the most recent
    notice is echoed through ``debugPrintNotice``. The transaction is
    committed afterwards.

    Args:
        dbConnection: Open connection to operate on.
        tableName: Table name. It is lower-cased before use.
        tableFormat: Column definitions placed between the parentheses of
            ``CREATE TABLE``. This is raw SQL and is NOT sanitised - it must
            never contain untrusted input.

    Raises:
        ValueError: If *tableName* or *tableFormat* contains the dollar-quote
            tag used to delimit the PL/pgSQL block.
        psycopg2.Error: If the server rejects the statement.
    """
    if _INIT_TABLE_DOLLAR_TAG in tableName or _INIT_TABLE_DOLLAR_TAG in tableFormat:
        raise ValueError("tableName/tableFormat must not contain " + _INIT_TABLE_DOLLAR_TAG)

    loweredName = tableName.lower()
    # The notice text is passed as an argument to RAISE NOTICE '%' so that a
    # '%' inside the table name is not treated as a format placeholder.
    query = sql.SQL("""
        DO {tag}
        BEGIN
            IF (EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = {nameLiteral})) THEN
                RAISE NOTICE '%', {existsNotice};
            ELSE
                RAISE NOTICE '%', {createNotice};
                CREATE TABLE {table} ( {columns} );
            END IF;
        END;
        {tag}
    """).format(
        tag=sql.SQL(_INIT_TABLE_DOLLAR_TAG),
        nameLiteral=sql.Literal(loweredName),
        existsNotice=sql.Literal(tableName + " Table already exists! Skipping table creation..."),
        createNotice=sql.Literal(tableName + " Table does not exist! Creating table..."),
        table=sql.Identifier(loweredName),
        columns=sql.SQL(tableFormat),
    )

    # The context manager closes the cursor even when execute() raises.
    with dbConnection.cursor() as dbCursor:
        dbCursor.execute(query)
        debugPrintNotice(dbConnection, -1)
        dbConnection.commit()


def _rollbackIfAborted(dbConnection: psycopg2.extensions.connection) -> None:
    """Roll back only when a failed statement left the transaction aborted.

    Without this, every later statement on the connection would be rejected
    until the transaction block ends. Best effort: problems while cleaning up
    are reported, never raised.
    """
    try:
        status = dbConnection.get_transaction_status()
        if status == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
            dbConnection.rollback()
    except Exception as error:
        errorPrint("Rollback failed! Unexpected error: " + repr(error))


def _runQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable,
              fetchResults: bool, commit: bool, label: str):
    """Shared implementation of ``_commitQuery`` and ``_execQuery``.

    *label* ("Commit" or "Exec") only selects the wording of the log messages.
    """
    try:
        debugPrint(label + " query executing...")
        # The context manager closes the cursor on every path.
        with dbConnection.cursor() as dbCursor:
            dbCursor.execute(query)
            if commit:
                dbConnection.commit()
            if fetchResults:
                return dbCursor.fetchall()
            return True
    except Exception as error:
        errorPrint(label + " query failed! Unexpected error: " + repr(error))
        _rollbackIfAborted(dbConnection)
        return None


# These base functions should not be called directly as they perform no string query sanitisation (Therefore vulnerable to SQL injection attacks)
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query*, commit, and optionally fetch its result rows.

    Returns:
        The list of fetched rows when *fetchResults* is true, ``True`` when it
        is false, or ``None`` if anything failed (the error is printed).
    """
    return _runQuery(dbConnection, query, fetchResults, True, "Commit")


def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query* without committing and optionally fetch its rows.

    Returns:
        The list of fetched rows when *fetchResults* is true, ``True`` when it
        is false, or ``None`` if anything failed (the error is printed).
    """
    return _runQuery(dbConnection, query, fetchResults, False, "Exec")


# Callable helper functions
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str],
              tableValues: list) -> tuple:
    """Insert one row and return it as stored (``RETURNING *``).

    Args:
        dbConnection: Open connection to operate on.
        tableName: Target table (quoted as an identifier, case preserved).
        tableFormat: Column names; each is lower-cased and quoted.
        tableValues: Values for the columns, in the same order.

    Returns:
        The inserted row.

    Raises:
        TypeError: If the insert failed (the underlying query returned None).
    """
    debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues)
               + ") into table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        INSERT INTO {table} ({format})
        VALUES ({values})
        RETURNING *;
    """).format(
        table=sql.Identifier(tableName),
        format=sql.SQL(", ").join(sql.Identifier(value.lower()) for value in tableFormat),
        values=sql.SQL(", ").join(sql.Literal(value) for value in tableValues),
    )
    debugPrint(sanitisedQuery.as_string(dbConnection))
    return _commitQuery(dbConnection, sanitisedQuery)[0]


def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int,
                      tableField: str) -> str:
    """Return column *tableField* of the row whose ``id`` equals *RowID*.

    The value is converted with ``str()`` before being returned.

    Raises:
        TypeError: If the query failed (the underlying query returned None).
        IndexError: If no row has that id.
    """
    debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID)
               + " in table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        SELECT {field} FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        id=sql.Literal(RowID),
    )
    return str(_execQuery(dbConnection, sanitisedQuery)[0][0])


def changeFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int,
                         tableField: str, newValue) -> typing.Optional[bool]:
    """Set column *tableField* to *newValue* on the row with id *RowID* and commit.

    Returns:
        ``True`` when the statement executed and was committed, ``None`` when
        it failed (the error is printed).
    """
    debugPrint("Attempting to change value of field name " + tableField + " in ID row " + str(RowID)
               + " in table name " + tableName + " to " + str(newValue) + "...")
    sanitisedQuery = sql.SQL("""
        UPDATE {table} SET {field} = {value} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        id=sql.Literal(RowID),
        value=sql.Literal(newValue),
    )
    return _commitQuery(dbConnection, sanitisedQuery, False)


def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
    """Return every column of the row whose ``id`` equals *RowID*.

    Raises:
        TypeError: If the query failed (the underlying query returned None).
        IndexError: If no row has that id.
    """
    debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        SELECT * FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        id=sql.Literal(RowID),
    )
    return tuple(_execQuery(dbConnection, sanitisedQuery)[0])


def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str,
                             fieldValue) -> bool:
    """Report whether any row of *tableName* has *fieldName* equal to *fieldValue*.

    Raises:
        TypeError: If the query failed (the underlying query returned None).
    """
    debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value "
               + str(fieldValue) + " exists...")
    sanitisedQuery = sql.SQL("""
        SELECT EXISTS(
            SELECT {fieldName} FROM {table} WHERE {fieldName} = {fieldValue}
        );
    """).format(
        table=sql.Identifier(tableName),
        fieldName=sql.Identifier(fieldName),
        fieldValue=sql.Literal(fieldValue),
    )
    return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])


def getRowRangeByID(dbConnection: psycopg2.extensions.connection, tableName: str, rangeStart: int,
                    rangeEnd: int, latestRecords: bool = True) -> tuple:
    """Return all rows whose ``id`` lies in ``[rangeStart, rangeEnd]``.

    Args:
        dbConnection: Open connection to operate on.
        tableName: Table to read from.
        rangeStart: Lowest id to include.
        rangeEnd: Highest id to include.
        latestRecords: Order by id descending when true, ascending otherwise.

    Returns:
        A tuple of rows (empty when nothing matches).

    Raises:
        TypeError: If the query failed (the underlying query returned None).
    """
    debugPrint("Getting rows from table name " + tableName + " from range " + str(rangeStart) + "-"
               + str(rangeEnd) + "...")
    sanitisedQuery = sql.SQL("""
        SELECT * FROM {table} WHERE id >= {start} AND id <= {end} ORDER BY id {order}
    """).format(
        table=sql.Identifier(tableName),
        start=sql.Literal(rangeStart),
        end=sql.Literal(rangeEnd),
        # Only ever one of two fixed keywords, so raw SQL is safe here.
        order=sql.SQL("DESC" if latestRecords else "ASC"),
    )
    return tuple(_execQuery(dbConnection, sanitisedQuery))