"""Thin PostgreSQL helper layer built on psycopg2.

Provides connection setup, create-if-missing table initialisation and a few
row/field accessors. The public helpers compose SQL with ``psycopg2.sql``
(``Identifier`` for names, ``Literal`` for values), so caller-supplied names
and values are quoted rather than spliced into the query text.
"""

import sys
import typing

import psycopg2
from psycopg2 import sql

# When True, the debug* helpers print "(DB HANDLER)" diagnostics to stdout.
debug: bool = True


def debugPrint(msg: str) -> None:
    """Print *msg* with the handler's debug prefix when ``debug`` is enabled."""
    if debug:
        print("(DB HANDLER) PRINT: " + msg)


def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
    """Print entry *i* of the connection's collected server notices (``-1`` = newest).

    Does nothing when ``debug`` is disabled or when no notice exists at index
    *i*; this is diagnostic output only and must not crash the caller.
    """
    if not debug:
        return
    try:
        notice = dbConnection.notices[i]
    except IndexError:
        return
    print("(DB HANDLER) " + notice)


def errorPrint(msg: str) -> None:
    """Unconditionally print *msg* with the handler's error prefix."""
    print("(DB HANDLER) ERROR: " + msg)


def connect(databaseOption: str, hostOption: str, userOption: str,
            passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
    """Open and return a psycopg2 connection, terminating the process on failure.

    Raises:
        SystemExit: with status 1 when psycopg2 cannot establish the connection.
    """
    debugPrint("Attempting to connect to database...")
    try:
        dbConnection = psycopg2.connect(database=databaseOption, host=hostOption,
                                        user=userOption, password=passwordOption,
                                        port=portOption)
    except psycopg2.Error as err:
        # Only driver errors are treated as "cannot connect"; anything else
        # (e.g. a programming mistake) propagates with its own traceback.
        errorPrint("Error occurred connecting to database! Exiting...")
        errorPrint(str(err).strip())
        sys.exit(1)
    debugPrint("Successfully connected to database!")
    return dbConnection


def initTable(dbConnection: psycopg2.extensions.connection, tableName: str,
              tableFormat: str) -> None:
    """Create table ``tableName.lower()`` from *tableFormat* unless it already exists.

    Existence is checked in the connection's current schema, which is where an
    unqualified ``CREATE TABLE`` places the table. The transaction is committed
    afterwards.

    NOTE: the other helpers in this module quote ``tableName`` verbatim (no
    lower-casing), so pass lower-case names for them to address this table.

    Args:
        tableName: table name; lower-cased and quoted as an identifier.
        tableFormat: raw SQL column-definition text placed between the
            parentheses of ``CREATE TABLE``. It is NOT escaped and must never
            come from untrusted input.
    """
    normalisedName = tableName.lower()
    # ``with`` guarantees the cursor is closed even if a statement fails.
    with dbConnection.cursor() as dbCursor:
        dbCursor.execute(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.tables
                WHERE table_schema = current_schema() AND table_name = %s
            )
            """,
            (normalisedName,),
        )
        tableExists = dbCursor.fetchone()[0]
        if tableExists:
            debugPrint(tableName + " Table does exist! Skipping creating table.")
        else:
            debugPrint(tableName + " Table does not exist! Creating table.")
            dbCursor.execute(
                sql.SQL("CREATE TABLE {table} ({format})").format(
                    table=sql.Identifier(normalisedName),
                    format=sql.SQL(tableFormat),
                )
            )
    dbConnection.commit()


# The two base functions below execute whatever Composable they are given
# without any sanitisation of their own (and are therefore open to SQL
# injection if handed hand-built SQL). Call them only with queries composed
# via psycopg2.sql, as the public helpers further down do.

def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
    """Execute *query*, commit, and return its result rows.

    Returns an empty list when the statement produces no result set (e.g. an
    INSERT without RETURNING) instead of letting ``fetchall`` raise.
    """
    debugPrint("Commit query executing...")
    with dbConnection.cursor() as dbCursor:
        dbCursor.execute(query)
        # ``description`` is None when the statement returned no result set.
        dbResults = dbCursor.fetchall() if dbCursor.description is not None else []
    dbConnection.commit()
    return dbResults


def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable) -> list:
    """Execute *query* and return its result rows without committing.

    Under psycopg2's default (non-autocommit) mode the implicit transaction
    stays open until the caller commits or rolls back.
    Returns an empty list when the statement produces no result set.
    """
    debugPrint("Exec query executing...")
    with dbConnection.cursor() as dbCursor:
        dbCursor.execute(query)
        dbResults = dbCursor.fetchall() if dbCursor.description is not None else []
    return dbResults


# Callable helper functions

def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str,
              tableFormat: list[str], tableValues: list) -> list:
    """Insert one row into *tableName*, commit, and return the inserted row(s).

    Args:
        tableName: target table, quoted verbatim as an identifier.
        tableFormat: column names, each quoted as an identifier.
        tableValues: values for those columns in the same order, each passed
            as an escaped literal.

    Returns:
        The rows produced by ``RETURNING *`` (a list of tuples).
    """
    debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        INSERT INTO {table} ({format}) VALUES ({values}) RETURNING *;
    """).format(
        table=sql.Identifier(tableName),
        format=sql.SQL(", ").join(sql.Identifier(value) for value in tableFormat),
        values=sql.SQL(", ").join(sql.Literal(value) for value in tableValues),
    )
    if debug:
        # Rendering needs the connection's encoding; skip the work when silent.
        debugPrint(sanitisedQuery.as_string(dbConnection))
    return _commitQuery(dbConnection, sanitisedQuery)


def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str,
                      RowID: int, tableField: str) -> str:
    """Return ``str()`` of column *tableField* in the row whose ``"id"`` equals *RowID*.

    A SQL NULL is returned as the string ``"None"``.

    Raises:
        IndexError: if no row in *tableName* has that id.
    """
    debugPrint("Attempting to get values of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        SELECT {field} FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        id=sql.Literal(RowID),
    )
    rows = _execQuery(dbConnection, sanitisedQuery)
    if not rows:
        raise IndexError("No row with id " + str(RowID) + " in table " + tableName)
    return str(rows[0][0])


def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str,
               RowID: int) -> tuple:
    """Return all columns of the row whose ``"id"`` equals *RowID* as a tuple.

    Raises:
        IndexError: if no row in *tableName* has that id.
    """
    debugPrint("Attempting to get row values by ID " + str(RowID) + " in table name " + tableName + "...")
    sanitisedQuery = sql.SQL("""
        SELECT * FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        id=sql.Literal(RowID),
    )
    rows = _execQuery(dbConnection, sanitisedQuery)
    if not rows:
        raise IndexError("No row with id " + str(RowID) + " in table " + tableName)
    return tuple(rows[0])


def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str,
                             fieldName: str, fieldValue) -> bool:
    """Return True if any row of *tableName* has *fieldName* equal to *fieldValue*.

    Passing ``None`` checks for rows where the column IS NULL (a plain
    ``= NULL`` comparison is never true in SQL).
    """
    debugPrint("Checking if field name " + fieldName + " in table " + tableName + " with value " + str(fieldValue) + " exists...")
    field = sql.Identifier(fieldName)
    if fieldValue is None:
        condition = sql.SQL("{field} IS NULL").format(field=field)
    else:
        condition = sql.SQL("{field} = {value}").format(field=field, value=sql.Literal(fieldValue))
    sanitisedQuery = sql.SQL("""
        SELECT EXISTS(
            SELECT 1 FROM {table} WHERE {condition}
        );
    """).format(
        table=sql.Identifier(tableName),
        condition=condition,
    )
    return bool(_execQuery(dbConnection, sanitisedQuery)[0][0])