/* * db.h * * See the README file for copyright information and how to reach the author. * */ #ifndef __DB_H #define __DB_H #include #include #include #include #include #include #include #include #include #include #include "common.h" class cDbTable; class cDbConnection; using namespace std; //*************************************************************************** // cDbService //*************************************************************************** class cDbService { public: enum Misc { maxIndexFields = 20 }; enum FieldFormat { ffUnknown = na, ffInt, ffUInt, ffAscii, // -> VARCHAR ffText, ffMlob, // -> MEDIUMBLOB ffFloat, ffDateTime, ffCount }; enum FieldType { ftUnknown = na, ftData = 1, ftPrimary = 2, ftMeta = 4, ftCalc = 8, ftAutoinc = 16, ftDef0 = 32 }; struct FieldDef { const char* name; FieldFormat format; int size; int index; int type; }; enum BindType { bndIn = 0x001, bndOut = 0x002, bndSet = 0x004 }; enum ProcType { ptProcedure, ptFunction }; struct IndexDef { const char* name; int fields[maxIndexFields+1]; int order; // not implemented yet }; static const char* toString(FieldFormat t); static FieldFormat toDictFormat(const char* format); static const char* formats[]; static const char* dictFormats[]; static FieldType toType(const char* type); static const char* types[]; }; typedef cDbService cDBS; //*************************************************************************** // cDbValue //*************************************************************************** class cDbValue : public cDbService { public: cDbValue(FieldDef* f = 0) { field = 0; strValue = 0; ownField = 0; if (f) setField(f); } cDbValue(const char* name, FieldFormat format, int size) { strValue = 0; ownField = new FieldDef; ownField->name = strdup(name); ownField->format = format; ownField->size = size; ownField->type = ftData; field = ownField; clear(); } virtual ~cDbValue() { free(); } void free() { clear(); ::free(strValue); strValue = 0; if (ownField) { ::free((char*)ownField->name); // böser cast ;) delete ownField; ownField = 0; } field = 0; } void clear() { if (strValue) *strValue = 0; strValueSize = 0; numValue = 0; floatValue = 0; memset(&timeValue, 0, sizeof(timeValue)); nullValue = 1; initialized = no; } virtual void setField(FieldDef* f) { free(); field = f; if (field) strValue = (char*)calloc(field->size+TB, sizeof(char)); } virtual FieldDef* getField() { return field; } virtual const char* getName() { return field->name; } void setValue(const char* value, int size = 0) { clear(); if (field->format != ffAscii && field->format != ffText && field->format != ffMlob) { tell(0, "Setting invalid field format for '%s', expected ASCII or MLOB", field->name); return; } if (field->format == ffMlob && !size) { tell(0, "Missing size for MLOB field '%s'", field->name); return; } if (value && size) { if (size > field->size) { tell(0, "Warning, size of %d for '%s' exeeded, got %d bytes!", field->size, field->name, size); size = field->size; } memcpy(strValue, value, size); strValue[size] = 0; strValueSize = size; nullValue = 0; } else if (value) { if (strlen(value) > (size_t)field->size) tell(0, "Warning, size of %d for '%s' exeeded [%s]", field->size, field->name, value); sprintf(strValue, "%.*s", field->size, value); strValueSize = strlen(strValue); nullValue = 0; } } void setCharValue(char value) { char tmp[2]; tmp[0] = value; tmp[1] = 0; setValue(tmp); } void setValue(int value) { setValue((long)value); } void setValue(long value) { if (field->format == ffInt || field->format == ffUInt) { numValue = value; nullValue = 0; } else if (field->format == ffDateTime) { struct tm tm; time_t v = value; memset(&tm, 0, sizeof(tm)); localtime_r(&v, &tm); timeValue.year = tm.tm_year + 1900; timeValue.month = tm.tm_mon + 1; timeValue.day = tm.tm_mday; timeValue.hour = tm.tm_hour; timeValue.minute = tm.tm_min; timeValue.second = tm.tm_sec; nullValue = 0; } else { tell(0, "Setting invalid field format for '%s'", field->name); } } void setValue(double value) { if (field->format == ffInt || field->format == ffUInt) { numValue = value; nullValue = 0; } else if (field->format == ffFloat) { floatValue = value; nullValue = 0; } else { tell(0, "Setting invalid field format for '%s'", field->name); } } int hasValue(long value) { if (field->format == ffInt || field->format == ffUInt) return numValue == value; if (field->format == ffDateTime) return no; // to be implemented! tell(0, "Setting invalid field format for '%s'", field->name); return no; } int hasValue(double value) { if (field->format == ffInt || field->format == ffUInt) return numValue == value; if (field->format == ffFloat) return floatValue == value; tell(0, "Setting invalid field format for '%s'", field->name); return no; } int hasValue(const char* value) { if (!value) value = ""; if (field->format != ffAscii && field->format != ffText) { tell(0, "Checking invalid field format for '%s', expected ASCII or MLOB", field->name); return no; } return strcmp(getStrValue(), value) == 0; } time_t getTimeValue() { struct tm tm; memset(&tm, 0, sizeof(tm)); tm.tm_isdst = -1; // force DST auto detect tm.tm_year = timeValue.year - 1900; tm.tm_mon = timeValue.month - 1; tm.tm_mday = timeValue.day; tm.tm_hour = timeValue.hour; tm.tm_min = timeValue.minute; tm.tm_sec = timeValue.second; return mktime(&tm); } unsigned long* getStrValueSizeRef() { return &strValueSize; } unsigned long getStrValueSize() { return strValueSize; } const char* getStrValue() { return !isNull() && strValue ? strValue : ""; } long getIntValue() { return !isNull() ? numValue : 0; } float getFloatValue() { return !isNull() ? floatValue : 0; } int isNull() { return nullValue; } char* getStrValueRef() { return strValue; } long* getIntValueRef() { return &numValue; } MYSQL_TIME* getTimeValueRef() { return &timeValue; } float* getFloatValueRef() { return &floatValue; } my_bool* getNullRef() { return &nullValue; } private: FieldDef* ownField; FieldDef* field; long numValue; float floatValue; MYSQL_TIME timeValue; char* strValue; unsigned long strValueSize; my_bool nullValue; int initialized; }; //*************************************************************************** // cDbStatement //*************************************************************************** class cDbStatement : public cDbService { public: cDbStatement(cDbTable* aTable); cDbStatement(cDbConnection* aConnection, const char* stmt = ""); virtual ~cDbStatement(); int execute(int noResult = no); int find(); int fetch(); int freeResult(); // interface int build(const char* format, ...); void setBindPrefix(const char* p) { bindPrefix = p; } void clrBindPrefix() { bindPrefix = 0; } int bind(cDbValue* value, int mode, const char* delim = 0); int bind(cDbTable* aTable, int field, int mode, const char* delim); int bind(int field, int mode, const char* delim = 0); int bindCmp(const char* table, cDbValue* value, const char* comp, const char* delim = 0); int bindCmp(const char* table, int field, cDbValue* value, const char* comp, const char* delim = 0); // .. int prepare(); int getAffected() { return affected; } int getResultCount(); const char* asText() { return stmtTxt.c_str(); } void showStat(); // data static int explain; // debug explain private: void clear(); int appendBinding(cDbValue* value, BindType bt); string stmtTxt; MYSQL_STMT* stmt; int affected; cDbConnection* connection; cDbTable* table; int inCount; MYSQL_BIND* inBind; // to db int outCount; MYSQL_BIND* outBind; // from db (result) MYSQL_RES* metaResult; const char* bindPrefix; int firstExec; // debug explain unsigned long callsPeriod; unsigned long callsTotal; double duration; }; //*************************************************************************** // cDbStatements //*************************************************************************** class cDbStatements { public: cDbStatements() { statisticPeriod = time(0); } ~cDbStatements() {}; void append(cDbStatement* s) { statements.push_back(s); } void remove(cDbStatement* s) { statements.remove(s); } void showStat(const char* name) { tell(0, "Statement statistic of last %ld seconds from '%s':", time(0) - statisticPeriod, name); for (std::list::iterator it = statements.begin() ; it != statements.end(); ++it) { if (*it) (*it)->showStat(); } statisticPeriod = time(0); } private: time_t statisticPeriod; std::list statements; }; //*************************************************************************** // Class Database Row //*************************************************************************** class cDbRow : public cDbService { public: cDbRow(FieldDef* f) { count = 0; fieldDef = 0; useFields(f); dbValues = new cDbValue[count]; for (int f = 0; f < count; f++) dbValues[f].setField(getField(f)); } virtual ~cDbRow() { delete[] dbValues; } void clear() { for (int f = 0; f < count; f++) dbValues[f].clear(); } virtual FieldDef* getField(int f) { return f < 0 ? 0 : fieldDef+f; } virtual int fieldCount() { return count; } void setValue(int f, const char* value, int size = 0) { dbValues[f].setValue(value, size); } void setValue(int f, int value) { dbValues[f].setValue(value); } void setValue(int f, long value) { dbValues[f].setValue(value); } void setValue(int f, double value) { dbValues[f].setValue(value); } void setCharValue(int f, char value) { dbValues[f].setCharValue(value); } int hasValue(int f, const char* value) const { return dbValues[f].hasValue(value); } int hasValue(int f, long value) const { return dbValues[f].hasValue(value); } int hasValue(int f, double value) const { return dbValues[f].hasValue(value); } cDbValue* getValue(int f) { return &dbValues[f]; } const char* getStrValue(int f) const { return dbValues[f].getStrValue(); } long getIntValue(int f) const { return dbValues[f].getIntValue(); } float getFloatValue(int f) const { return dbValues[f].getFloatValue(); } int isNull(int f) const { return dbValues[f].isNull(); } protected: virtual void useFields(FieldDef* f) { fieldDef = f; for (count = 0; (fieldDef+count)->name; count++); } int count; // field count FieldDef* fieldDef; cDbValue* dbValues; }; //*************************************************************************** // Connection //*************************************************************************** class cDbConnection { public: cDbConnection() { mysql = 0; attached = 0; inTact = no; connectDropped = yes; } virtual ~cDbConnection() { close(); } int isConnected() { return getMySql() > 0; } int attachConnection() { static int first = yes; if (!mysql) { connectDropped = yes; tell(0, "Calling mysql_init(%ld)", syscall(__NR_gettid)); if (!(mysql = mysql_init(0))) return errorSql(this, "attachConnection(init)"); if (!mysql_real_connect(mysql, dbHost, dbUser, dbPass, dbName, dbPort, 0, 0)) { close(); tell(0, "Error, connecting to database at '%s' on port (%d) failed", dbHost, dbPort); return fail; } connectDropped = no; // init encoding if (encoding && *encoding) { if (mysql_set_character_set(mysql, encoding)) errorSql(this, "init(character_set)"); if (first) { tell(0, "SQL client character now '%s'", mysql_character_set_name(mysql)); first = no; } } } attached++; return success; } void close() { if (mysql) { tell(0, "Closing mysql connection and calling mysql_thread_end(%ld)", syscall(__NR_gettid)); mysql_close(mysql); mysql_thread_end(); mysql = 0; } } void detachConnection() { attached--; if (!attached) close(); } int check() { if (!isConnected()) return fail; query("SELECT SYSDATE();"); queryReset(); return isConnected() ? success : fail; } virtual int query(const char* format, ...) { int status = 1; MYSQL* h = getMySql(); if (h && format) { char* stmt; va_list more; va_start(more, format); vasprintf(&stmt, format, more); if ((status = mysql_query(h, stmt))) errorSql(this, stmt); free(stmt); } return status ? fail : success; } virtual void queryReset() { if (getMySql()) { MYSQL_RES* result = mysql_use_result(getMySql()); mysql_free_result(result); } } virtual int executeSqlFile(const char* file) { FILE* f; int res; char* buffer; int size = 1000; int nread = 0; if (!getMySql()) return fail; if (!(f = fopen(file, "r"))) { tell(0, "Fatal: Can't access '%s'; %s", file, strerror(errno)); return fail; } buffer = (char*)malloc(size+1); while (res = fread(buffer+nread, 1, 1000, f)) { nread += res; size += 1000; buffer = srealloc(buffer, size+1); } fclose(f); buffer[nread] = 0; // execute statement tell(2, "Executing '%s'", buffer); if (query("%s", buffer)) { free(buffer); return errorSql(this, "executeSqlFile()"); } free(buffer); return success; } virtual int startTransaction() { inTact = yes; return query("START TRANSACTION"); } virtual int commit() { inTact = no; return query("COMMIT"); } virtual int rollback() { inTact = no; return query("ROLLBACK"); } virtual int inTransaction() { return inTact; } MYSQL* getMySql() { if (connectDropped && mysql) close(); return mysql; } int getAttachedCount() { return attached; } // -------------- // static stuff // set/get connecting data static void setHost(const char* s) { free(dbHost); dbHost = strdup(s); } static const char* getHost() { return dbHost; } static void setName(const char* s) { free(dbName); dbName = strdup(s); } static const char* getName() { return dbName; } static void setUser(const char* s) { free(dbUser); dbUser = strdup(s); } static const char* getUser() { return dbUser; } static void setPass(const char* s) { free(dbPass); dbPass = strdup(s); } static const char* getPass() { return dbPass; } static void setPort(int port) { dbPort = port; } static int getPort() { return dbPort; } static void setEncoding(const char* enc) { free(encoding); encoding = strdup(enc); } static const char* getEncoding() { return encoding; } int errorSql(cDbConnection* mysql, const char* prefix, MYSQL_STMT* stmt = 0, const char* stmtTxt = 0); void showStat(const char* name = "") { statements.showStat(name); } // ----------------------------------------------------------- // init() and exit() must exactly called 'once' per process static int init(key_t semKey) { if (semKey && !sem) sem = new Sem(semKey); if (!sem || sem->check() == success) { // call only once per process if (sem) tell(1, "Info: Calling mysql_library_init()"); if (mysql_library_init(0, 0, 0)) { tell(0, "Error: mysql_library_init() failed"); return fail; } } else if (sem) { tell(1, "Info: Skipping calling mysql_library_init(), it's already done!"); } if (sem) sem->inc(); // count usage per process return success; } static int exit() { mysql_thread_end(); if (sem) sem->dec(); if (!sem || sem->check() == success) { if (sem) tell(1, "Info: Released the last usage of mysql_lib, calling mysql_library_end() now"); mysql_library_end(); } else if (sem) { tell(1, "Info: The mysql_lib is still in use, skipping mysql_library_end() call"); } free(dbHost); free(dbUser); free(dbPass); free(dbName); free(encoding); return done; } MYSQL* mysql; cDbStatements statements; // all statements of this connection private: int initialized; int attached; int inTact; int connectDropped; static Sem* sem; static char* encoding; // connecting data static char* dbHost; static int dbPort; static char* dbName; // database name static char* dbUser; static char* dbPass; }; //*************************************************************************** // cDbTable //*************************************************************************** class cDbTable : public cDbService { public: cDbTable(cDbConnection* aConnection, FieldDef* f, IndexDef* i = 0); virtual ~cDbTable(); virtual const char* TableName() = 0; virtual int open(); virtual int close(); virtual int find(); virtual void reset() { reset(stmtSelect); } virtual int find(cDbStatement* stmt); virtual int fetch(cDbStatement* stmt); virtual void reset(cDbStatement* stmt); virtual int insert(); virtual int update(); virtual int store(); virtual int deleteWhere(const char* where); virtual int countWhere(const char* where, int& count, const char* what = 0); virtual int truncate(); // interface to cDbRow void clear() { row->clear(); } void setValue(int f, const char* value, int size = 0) { row->setValue(f, value, size); } void setValue(int f, int value) { row->setValue(f, value); } void setValue(int f, long value) { row->setValue(f, value); } void setValue(int f, double value) { row->setValue(f, value); } void setCharValue(int f, char value) { row->setCharValue(f, value); } int hasValue(int f, const char* value) { return row->hasValue(f, value); } int hasValue(int f, long value) { return row->hasValue(f, value); } int hasValue(int f, double value) { return row->hasValue(f, value); } const char* getStrValue(int f) const { return row->getStrValue(f); } long getIntValue(int f) const { return row->getIntValue(f); } float getFloatValue(int f) const { return row->getFloatValue(f); } int isNull(int f) const { return row->isNull(f); } FieldDef* getField(int f) { return row->getField(f); } cDbValue* getValue(int f) { return row->getValue(f); } int fieldCount() { return row->fieldCount(); } cDbRow* getRow() { return row; } cDbConnection* getConnection() { return connection; } MYSQL* getMySql() { return connection->getMySql(); } int isConnected() { return connection && connection->getMySql(); } virtual IndexDef* getIndex(int i) { return indices+i; } virtual int exist(const char* name = 0); virtual int createTable(); // static stuff static void setConfPath(const char* cpath) { free(confPath); confPath = strdup(cpath); } protected: virtual int init(); virtual int createIndices(); virtual int checkIndex(const char* idxName, int& fieldCount); virtual void copyValues(cDbRow* r); // data cDbRow* row; int holdInMemory; // hold table additionally in memory (not implemented yet) IndexDef* indices; // basic statements cDbStatement* stmtSelect; cDbStatement* stmtInsert; cDbStatement* stmtUpdate; cDbConnection* connection; // statics static char* confPath; }; //*************************************************************************** // cDbView //*************************************************************************** class cDbView : public cDbService { public: cDbView(cDbConnection* c, const char* aName) { connection = c; name = strdup(aName); } ~cDbView() { free(name); } int exist() { if (connection->getMySql()) { MYSQL_RES* result = mysql_list_tables(connection->getMySql(), name); MYSQL_ROW tabRow = mysql_fetch_row(result); mysql_free_result(result); return tabRow ? yes : no; } return no; } int create(const char* path, const char* sqlFile) { int status; char* file = 0; asprintf(&file, "%s/%s", path, sqlFile); tell(0, "Creating view '%s' using definition in '%s'", name, file); status = connection->executeSqlFile(file); free(file); return status; } int drop() { tell(0, "Drop view '%s'", name); return connection->query("drop view %s", name); } protected: cDbConnection* connection; char* name; }; //*************************************************************************** // cDbProcedure //*************************************************************************** class cDbProcedure : public cDbService { public: cDbProcedure(cDbConnection* c, const char* aName, ProcType pt = ptProcedure) { connection = c; type = pt; name = strdup(aName); } ~cDbProcedure() { free(name); } const char* getName() { return name; } int call(int ll = 1) { if (!connection || !connection->getMySql()) return fail; cDbStatement stmt(connection); tell(ll, "Calling '%s'", name); stmt.build("call %s", name); if (stmt.prepare() != success || stmt.execute() != success) return fail; tell(ll, "'%s' suceeded", name); return success; } int created() { if (!connection || !connection->getMySql()) return fail; cDbStatement stmt(connection); stmt.build("show %s status where name = '%s'", type == ptProcedure ? "procedure" : "function", name); if (stmt.prepare() != success || stmt.execute() != success) { tell(0, "%s check of '%s' failed", type == ptProcedure ? "Procedure" : "Function", name); return no; } else { if (stmt.getResultCount() != 1) return no; } return yes; } int create(const char* path) { int status; char* file = 0; asprintf(&file, "%s/%s.sql", path, name); tell(1, "Creating %s '%s'", type == ptProcedure ? "procedure" : "function", name); status = connection->executeSqlFile(file); free(file); return status; } int drop() { tell(1, "Drop %s '%s'", type == ptProcedure ? "procedure" : "function", name); return connection->query("drop %s %s", type == ptProcedure ? "procedure" : "function", name); } static int existOnFs(const char* path, const char* name) { int state; char* file = 0; asprintf(&file, "%s/%s.sql", path, name); state = fileExists(file); free(file); return state; } protected: cDbConnection* connection; ProcType type; char* name; }; //*************************************************************************** #endif //__DB_H