178 lines
7.0 KiB
Python
178 lines
7.0 KiB
Python
import sys
|
|
import typing
|
|
import psycopg2
|
|
from psycopg2 import sql
|
|
|
|
# Module-wide switch read by debugPrint and debugPrintNotice: while True they
# print diagnostics, while False they print nothing. errorPrint ignores it.
debug: bool = True
|
|
|
|
def debugPrint(msg: str) -> None:
    """Print *msg* with the "(DB HANDLER) PRINT: " prefix, only when `debug` is truthy."""
    if not debug:
        return
    line = "(DB HANDLER) PRINT: " + msg
    print(line)
|
|
|
|
def debugPrintNotice(dbConnection: psycopg2.extensions.connection, i: int) -> None:
    """Print entry *i* of ``dbConnection.notices`` (server NOTICE messages), only when `debug` is truthy.

    Args:
        dbConnection: Connection whose collected notices are read.
        i: Index into the notices list; callers in this module pass -1 for the latest one.
    """
    if not debug:
        return
    notice = dbConnection.notices[i]
    print("(DB HANDLER) " + notice)
|
|
|
|
def errorPrint(msg: str) -> None:
    """Unconditionally print *msg* with the "(DB HANDLER) ERROR: " prefix."""
    prefix = "(DB HANDLER) ERROR: "
    print("".join((prefix, msg)))
|
|
|
|
def connect(databaseOption: str, hostOption: str, userOption: str, passwordOption: str, portOption: str) -> psycopg2.extensions.connection:
    """Open a psycopg2 connection to a PostgreSQL database, exiting the process on failure.

    Args:
        databaseOption: Name of the database to connect to.
        hostOption: Database server host.
        userOption: User name to authenticate as.
        passwordOption: Password for ``userOption``.
        portOption: Server port.

    Returns:
        The open connection.

    Raises:
        SystemExit: via ``sys.exit(1)`` when ``psycopg2.connect`` raises.
    """
    debugPrint("Attempting to connect to database...")
    try:
        dbConnection = psycopg2.connect(database=databaseOption,
                                        host=hostOption,
                                        user=userOption,
                                        password=passwordOption,
                                        port=portOption)
    except Exception as error:
        # `Exception` rather than a bare `except:` so Ctrl-C (KeyboardInterrupt)
        # is not turned into a misleading "connection error" exit; the cause is
        # printed so the failure can actually be diagnosed.
        errorPrint("Error occurred connecting to database! Exiting... (" + repr(error) + ")")
        sys.exit(1)
    debugPrint("Successfully connected to database!")
    return dbConnection
|
|
|
|
def disconnect(dbConnection: psycopg2.extensions.connection) -> None:
    """Close *dbConnection*, exiting the process if closing raises.

    Args:
        dbConnection: The connection to close.

    Raises:
        SystemExit: via ``sys.exit(1)`` when ``close()`` raises.
    """
    debugPrint("Attempting to disconnect database...")
    try:
        dbConnection.close()
    except Exception as error:
        # errorPrint (not debugPrint) so the reason for the non-zero exit is
        # shown even when debug output is switched off.
        errorPrint("Failed to disconnect database! Exiting... (" + repr(error) + ")")
        sys.exit(1)
    debugPrint("Successfully disconnected database!")
|
|
|
|
def initTable(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: str) -> None:
    """Create table *tableName* (lower-cased) with columns *tableFormat* unless it already exists.

    The table name is passed as a query parameter / quoted identifier, so it
    cannot inject SQL. *tableFormat* is the raw column-definition list (e.g.
    ``"id SERIAL PRIMARY KEY, name TEXT"``) and is inserted verbatim. It must
    come from trusted code, never from user input.

    The work is committed on success and rolled back if any statement fails;
    the psycopg2 error is then re-raised to the caller.

    Args:
        dbConnection: Open connection to create the table on.
        tableName: Table name; lower-cased before the existence check and creation.
        tableFormat: Trusted SQL column definitions placed inside ``CREATE TABLE ... ( )``.
    """
    tableIdentifier = tableName.lower()

    # `with dbConnection` commits on normal exit and rolls back on an exception;
    # `with cursor` guarantees the cursor is closed either way.
    with dbConnection, dbConnection.cursor() as dbCursor:
        # NOTE(review): like before, this matches the name in any schema visible
        # through INFORMATION_SCHEMA, not only the schema CREATE TABLE targets.
        dbCursor.execute(
            "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = %s)",
            (tableIdentifier,),
        )
        tableExists = dbCursor.fetchone()[0]

        if tableExists:
            debugPrint(tableName + " Table already exists! Skipping table creation...")
            return

        debugPrint(tableName + " Table does not exist! Creating table...")
        # IF NOT EXISTS keeps this safe if another session creates the table
        # between the check above and this statement.
        dbCursor.execute(
            sql.SQL("CREATE TABLE IF NOT EXISTS {table} ({columns})").format(
                table=sql.Identifier(tableIdentifier),
                columns=sql.SQL(tableFormat),
            )
        )
|
|
|
|
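# These base functions should not be called directly: they execute whatever query they are given without any sanitisation,
# so a query built by string concatenation would be vulnerable to SQL injection. Use the helper functions below, which
# compose their queries with psycopg2.sql (Identifier/Literal).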
|
|
def _commitQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query*, commit, and optionally return the rows it produced.

    Private: *query* is executed exactly as given, so it must already be
    safely composed (e.g. with psycopg2.sql).

    Args:
        dbConnection: Open connection to run the query on.
        query: The composed SQL statement.
        fetchResults: When True, fetch and return all result rows. The statement
            must then produce a result set (SELECT, ``... RETURNING``). Otherwise
            fetching raises and None is returned, although the commit has
            already happened.

    Returns:
        The fetched rows when fetchResults is True, True when it is False, or
        None if anything raised (the error is printed via errorPrint).
    """
    try:
        debugPrint("Commit query executing...")
        # `with` closes the cursor on every path, including early returns and errors.
        with dbConnection.cursor() as dbCursor:
            dbCursor.execute(query)
            dbConnection.commit()
            if fetchResults:
                return dbCursor.fetchall()
            return True
    except Exception as error:
        errorPrint("Commit query failed! Unexpected error: " + repr(error))
        # A failed statement leaves the transaction aborted, and the server then
        # rejects every further command on this connection until a rollback.
        # Nothing in an aborted transaction can be committed any more, so
        # rolling it back loses no work.
        try:
            if dbConnection.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
                dbConnection.rollback()
        except psycopg2.Error as rollbackError:
            errorPrint("Rollback after failed query also failed: " + repr(rollbackError))
        return None
|
|
def _execQuery(dbConnection: psycopg2.extensions.connection, query: sql.Composable, fetchResults: bool = True):
    """Execute *query* without committing and optionally return the rows it produced.

    Private: *query* is executed exactly as given, so it must already be
    safely composed (e.g. with psycopg2.sql).

    NOTE(review): there is no commit here, so in psycopg2's default
    (non-autocommit) mode the transaction opened by this query stays open until
    something else on the connection commits or rolls back.

    Args:
        dbConnection: Open connection to run the query on.
        query: The composed SQL statement.
        fetchResults: When True, fetch and return all result rows.

    Returns:
        The fetched rows when fetchResults is True, True when it is False, or
        None if anything raised (the error is printed via errorPrint).
    """
    try:
        debugPrint("Exec query executing...")
        # `with` closes the cursor on every path, including early returns and errors.
        with dbConnection.cursor() as dbCursor:
            dbCursor.execute(query)
            if fetchResults:
                return dbCursor.fetchall()
            return True
    except Exception as error:
        errorPrint("Exec query failed! Unexpected error: " + repr(error))
        # A failed statement leaves the transaction aborted, and the server then
        # rejects every further command on this connection until a rollback.
        # Work in an aborted transaction can no longer be committed, so rolling
        # back only in that state discards nothing recoverable.
        try:
            if dbConnection.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
                dbConnection.rollback()
        except psycopg2.Error as rollbackError:
            errorPrint("Rollback after failed query also failed: " + repr(rollbackError))
        return None
|
|
|
|
# Callable helper functions
|
|
def insertRow(dbConnection: psycopg2.extensions.connection, tableName: str, tableFormat: list[str], tableValues: list) -> typing.Optional[tuple]:
    """Insert one row into *tableName*, commit, and return the stored row.

    Identifiers are quoted with sql.Identifier and values are bound with
    sql.Literal, so the inputs cannot inject SQL.

    Args:
        dbConnection: Open connection to insert through.
        tableName: Target table, quoted as given (so case-sensitive).
        tableFormat: Column names, each lower-cased before quoting.
        tableValues: Values for those columns, in the same order.

    Returns:
        The inserted row as returned by ``RETURNING *`` (a tuple with psycopg2's
        default cursor), or None if the insert failed. The failure is already
        reported by _commitQuery.
    """
    debugPrint("Attempting to insert new row (" + str(tableFormat) + ") VALUES (" + str(tableValues) + ") into table name " + tableName + "...")

    sanitisedQuery = sql.SQL("""
        INSERT INTO {table} ({format})
        VALUES ({values})
        RETURNING *;
    """).format(
        table=sql.Identifier(tableName),
        format=sql.SQL(", ").join(
            sql.Identifier(value.lower()) for value in tableFormat
        ),
        values=sql.SQL(", ").join(
            sql.Literal(value) for value in tableValues
        ),
    )

    debugPrint(sanitisedQuery.as_string(dbConnection))

    insertedRows = _commitQuery(dbConnection, sanitisedQuery)
    # _commitQuery signals failure with None; indexing that directly would
    # surface as an unrelated "'NoneType' object is not subscriptable".
    if not insertedRows:
        return None
    return insertedRows[0]
|
|
|
|
def getFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str) -> str:
    """Return column *tableField* of the row whose "id" equals *RowID*, converted with str().

    A NULL column therefore comes back as the string "None". No matching row
    raises IndexError, and a failed query (for which _execQuery returns None)
    raises TypeError.
    """
    debugPrint("".join([
        "Attempting to get values of field name ", tableField,
        " in ID row ", str(RowID),
        " in table name ", tableName, "...",
    ]))

    fieldQuery = sql.SQL("""
        SELECT {field} FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        id=sql.Literal(RowID),
    )

    matchingRows = _execQuery(dbConnection, fieldQuery)
    fieldValue = matchingRows[0][0]
    return str(fieldValue)
|
|
|
|
def changeFieldValueByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int, tableField: str, newValue) -> typing.Optional[bool]:
    """Set column *tableField* to *newValue* on the row whose "id" equals *RowID*, then commit.

    Identifiers are quoted with sql.Identifier and values are bound with
    sql.Literal, so the inputs cannot inject SQL.

    Returns:
        True once the UPDATE has been committed, or None if it failed (the error
        is printed by _commitQuery). True is also returned when no row has that
        id, because the number of affected rows is not checked.
    """
    debugPrint("Attempting to change value of field name " + tableField + " in ID row " + str(RowID) + " in table name " + tableName + " to " + str(newValue) + "...")

    sanitisedQuery = sql.SQL("""
        UPDATE {table} SET {field} = {value} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        field=sql.Identifier(tableField),
        id=sql.Literal(RowID),
        value=sql.Literal(newValue),
    )

    # fetchResults=False: UPDATE without RETURNING produces no rows to fetch.
    return _commitQuery(dbConnection, sanitisedQuery, False)
|
|
|
|
def getRowByID(dbConnection: psycopg2.extensions.connection, tableName: str, RowID: int) -> tuple:
    """Return every column of the row whose "id" equals *RowID* as a tuple.

    No matching row raises IndexError, and a failed query (for which _execQuery
    returns None) raises TypeError.
    """
    debugPrint("".join([
        "Attempting to get row values by ID ", str(RowID),
        " in table name ", tableName, "...",
    ]))

    rowQuery = sql.SQL("""
        SELECT * FROM {table} WHERE "id" = {id}
    """).format(
        table=sql.Identifier(tableName),
        id=sql.Literal(RowID),
    )

    matchingRows = _execQuery(dbConnection, rowQuery)
    firstRow = matchingRows[0]
    return tuple(firstRow)
|
|
|
|
def checkFieldValueExistence(dbConnection: psycopg2.extensions.connection, tableName: str, fieldName: str, fieldValue) -> bool:
    """Report whether any row of *tableName* has column *fieldName* equal to *fieldValue*.

    The comparison uses SQL ``=``, so a *fieldValue* of None (rendered as NULL)
    never matches. A failed query (for which _execQuery returns None) raises
    TypeError.
    """
    debugPrint("".join([
        "Checking if field name ", fieldName,
        " in table ", tableName,
        " with value ", str(fieldValue), " exists...",
    ]))

    existenceQuery = sql.SQL("""
        SELECT EXISTS(
            SELECT
                {fieldName}
            FROM
                {table}
            WHERE
                {fieldName} = {fieldValue}
        );
    """).format(
        table=sql.Identifier(tableName),
        fieldName=sql.Identifier(fieldName),
        fieldValue=sql.Literal(fieldValue),
    )

    resultRows = _execQuery(dbConnection, existenceQuery)
    existsFlag = resultRows[0][0]
    return bool(existsFlag)
|
|
|
|
def getRowRangeByID(dbConnection: psycopg2.extensions.connection, tableName: str, rangeStart: int, rangeEnd: int, latestRecords=True) -> tuple:
    """Return all rows whose id lies in the inclusive range [rangeStart, rangeEnd], as a tuple.

    Rows are ordered by id, descending when *latestRecords* is truthy and
    ascending otherwise. A failed query (for which _execQuery returns None)
    raises TypeError.
    """
    debugPrint("".join([
        "Getting rows from table name ", tableName,
        " from range ", str(rangeStart), "-", str(rangeEnd), "...",
    ]))

    # Only one of two fixed keywords is ever chosen, so composing it as raw SQL is safe.
    sortDirection = "DESC" if latestRecords else "ASC"

    rangeQuery = sql.SQL("""
        SELECT * FROM {table} WHERE id >= {start} AND id <= {end} ORDER BY id {order}
    """).format(
        table=sql.Identifier(tableName),
        start=sql.Literal(rangeStart),
        end=sql.Literal(rangeEnd),
        order=sql.SQL(sortDirection),
    )

    rangeRows = _execQuery(dbConnection, rangeQuery)
    return tuple(rangeRows)
|