// act-cluster-lib/src/db_client.cpp (C++; listing metadata: 219 lines, 6.5 KiB)

/*************************************************************************
*
* Copyright (c) 2023 Fabian Posch
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor,
* Boston, MA 02110-1301, USA.
*
**************************************************************************
*/
#include <iostream>
#include "util.h"
#include "db_client.hpp"
namespace db {
using namespace std;

// Namespace-scope state of the db module.
// NOTE(review): neither variable is referenced in this part of the file;
// presumably it is used elsewhere through an extern declaration. Verify before changing.
string hostname{};
bool connected{false};
/**
 * Create a connection wrapper that runs `setup` after every successful connect().
 *
 * @param credentials  server/port/user/password/database/version to connect with (copied)
 * @param setup        callback handed the freshly opened pqxx connection
 */
Connection::Connection(
    db_credentials_t& credentials,
    std::function<void(pqxx::connection *c)> setup
) : setup_function(std::move(setup)) {
    this->db_credentials = credentials;
    // No connection exists until connect() succeeds. The destructor, disconnect()
    // and ensure_connected() all compare `c` against nullptr, so it must never be
    // left indeterminate. (Harmless if the header already default-initialises it.)
    this->c = nullptr;
}
/**
 * Create a connection wrapper without a post-connect setup callback.
 *
 * @param credentials  server/port/user/password/database/version to connect with (copied)
 */
Connection::Connection(db_credentials_t& credentials) {
    this->db_credentials = credentials;
    // No connection exists until connect() succeeds. The destructor, disconnect()
    // and ensure_connected() all compare `c` against nullptr.
    this->c = nullptr;
}
/**
 * Release the database connection.
 *
 * connect() allocates `c` with `new`, so this object owns it. Deleting it also
 * closes it: pqxx::connection's destructor closes an open connection and
 * swallows errors, unlike close(), which may throw. Deleting a nullptr is a no-op.
 *
 * NOTE(review): because Connection owns a raw pointer, copying a Connection
 * would make two objects delete the same connection. Consider deleting the
 * copy operations in the header.
 */
Connection::~Connection() {
    delete c;
}
bool Connection::connect() {
bool connection_established = false;
for (int i = 0; i < MAX_CON_RETRIES; i++) {
try {
// create the connection object
this->c = new pqxx::connection(
"host=" + db_credentials.server + " "
"port=" + std::to_string(db_credentials.port) + " "
"user=" + db_credentials.uname + " "
"password=" + db_credentials.pwd + " "
"dbname=" + db_credentials.dbase
);
// make sure the database actually has the version we need
pqxx::work txn(*(this->c));
auto db_info = txn.exec1("SELECT * FROM info;");
txn.commit();
if (db_info["db_version"].as<int>() != this->db_credentials.version) {
this->disconnect();
cerr << "Error: Unsupported database version! Command expects ";
cerr << this->db_credentials.version;
cerr << ", server provides " << db_info["db_version"].as<int>() << "!" << endl;
return false;
}
if (this->setup_function != nullptr) {
// execute the initialization function
this->setup_function(c);
}
connection_established = true;
} catch (const exception &e) {
cerr << "Error: Could not connect to database:" << endl;
cerr << e.what() << endl;
}
}
return connection_established;
}
/**
 * Make sure an open connection exists, reconnecting if necessary.
 *
 * @return true if the connection was already open or connect() succeeded.
 */
bool Connection::ensure_connected() {
    if (c != nullptr && c->is_open()) {
        return true;
    }
    // `c` may be nullptr here (never connected), so only close an existing object.
    if (c != nullptr) {
        c->close();
    }
    return connect();
}
/**
 * Close the database connection if one is currently open.
 *
 * The connection object itself is kept; only the session is closed.
 */
void Connection::disconnect() {
    // Nothing to close: never connected, or already closed.
    if (c == nullptr || !c->is_open()) {
        return;
    }
    c->close();
}
/**
 * Prepare a batch of named statements on the current connection.
 *
 * @param statements  (name, SQL definition) pairs, prepared in order
 * @return false if there is no connection object, or if any prepare throws.
 *         Statements before the failing one stay prepared.
 */
bool Connection::prepare_statements(vector<pair<string, string>> statements) {
    if (this->c == nullptr) {
        cerr << "Error: Could not prepare statements!" << endl;
        cerr << "No database connection has been established." << endl;
        return false;
    }
    try {
        // Bind by const reference: each pair holds two strings we only read.
        for (const auto &statement : statements) {
            this->c->prepare(statement.first, statement.second);
        }
    } catch (const exception &e) {
        cerr << "Error: Could not prepare statements!" << endl;
        cerr << e.what() << endl;
        return false;
    }
    return true;
}
bool Connection::prepare_statement(std::string name, std::string statement) {
try {
this->c->prepare(name, statement);
} catch (const exception &e) {
cerr << "Error: Could not prepare statements!" << endl;
return false;
}
return true;
}
/**
 * Drop a previously prepared statement on the server via `DEALLOCATE`.
 *
 * @param name  name the statement was prepared under
 * @return the result of send_request(), i.e. whether the request went through
 */
bool Connection::unprepare_statement(std::string name) {
    auto unprepare_statement_lambda = [] (pqxx::work* txn, std::string statement) {
        // Quote the name as an SQL identifier. Unquoted, the server folds it to
        // lower case, so a mixed-case name registered via prepare() would not be
        // found. Quoting also stops the name from being parsed as SQL.
        txn->exec0("DEALLOCATE " + txn->quote_name(statement) + ";");
    };
    std::function<void(pqxx::work*, std::string)> unprepare_statement_func = unprepare_statement_lambda;
    return this->send_request(&unprepare_statement_func, name);
}
/**
 * Resolve a partial job ID (a prefix) to a full job ID.
 *
 * @param p_name  prefix to match against the `id` column of `jobs`
 * @param f_name  receives the first matching ID; only meaningful when 1 is returned
 * @return 1 if exactly one ID starts with `p_name`,
 *         0 if none does or the IDs could not be fetched,
 *        -1 if the prefix is ambiguous (two or more matches).
 */
int Connection::find_job(std::string p_name, std::string *f_name) {
    auto check_job_ids_lambda = [](pqxx::work *txn, int *ret, std::string p_name, std::string *f_name) {
        pqxx::result res {txn->exec("SELECT id FROM jobs;")};
        *ret = 0;
        for (const auto &row : res) {
            const std::string id = row["id"].as<std::string>();
            // skip IDs that do not start with the partial name we've been given
            if (id.rfind(p_name, 0) != 0) {
                continue;
            }
            if (*ret != 0) {
                // second match: the partial name is ambiguous, no need to look further
                *ret = -1;
                return;
            }
            *ret = 1;
            *f_name = id;
        }
    };
    std::function<void(pqxx::work*, int*, std::string, std::string*)> check_job_ids_func = check_job_ids_lambda;
    DEBUG_PRINT("Sending request...");
    int ret = 0;
    if (!this->send_request(&check_job_ids_func, &ret, p_name, f_name)) {
        std::cerr << "Error: Could not fetch job IDs from database!" << std::endl;
        // Report "no match" instead of 1: 1 means "unique match found" and would
        // let callers proceed with an unset f_name.
        return 0;
    }
    DEBUG_PRINT("Request complete.");
    return ret;
}
/**
 * Look up the status of a job by its full ID.
 *
 * @param job  full job ID (compared against the `id` column of `jobs`)
 * @return the stored status, or JobStatusType::UNKNOWN if the job is missing,
 *         matches more than one row, or the request fails.
 */
JobStatusType Connection::get_job_status(std::string job) {
    auto get_jstatus_lambda = [](pqxx::work *txn, std::string job, JobStatusType *status) {
        try {
            // txn->quote() escapes and single-quotes the ID, so a caller-supplied
            // string containing quotes cannot change the query.
            pqxx::row res {txn->exec1("SELECT job_status FROM jobs WHERE id=" + txn->quote(job) + ";")};
            *status = res["job_status"].as<JobStatusType>();
        } catch (const pqxx::unexpected_rows &) {
            std::cerr << "Error: Fetching job returned nothing or too many rows!" << std::endl;
            *status = JobStatusType::UNKNOWN;
        }
    };
    std::function<void(pqxx::work*, std::string, JobStatusType*)> get_jstatus_func = get_jstatus_lambda;
    DEBUG_PRINT("Sending request...");
    JobStatusType status = JobStatusType::UNKNOWN;
    if (!this->send_request(&get_jstatus_func, job, &status)) {
        std::cerr << "Error: Could not fetch job status from database!" << std::endl;
        return JobStatusType::UNKNOWN;
    }
    DEBUG_PRINT("Request complete.");
    return status;
}
/**
 * Placeholder for looking up a task's status.
 *
 * The task argument is ignored. A reminder is printed to stdout and
 * every task is reported as IN_PROGRESS until this is implemented.
 */
JobStatusType Connection::get_task_status(db::uuid_t /* task */) {
    const JobStatusType placeholder_status = JobStatusType::IN_PROGRESS;
    std::cout << "JOB STATUS called, implement me pls" << std::endl;
    return placeholder_status;
}
}