"""
sqlWrap - Adds convenience methods to DB-API2 connection objects.

Sep. 22, 2005 by Catherine Devlin (catherine.devlin@gmail.com, http://catherinedevlin.blogspot.com/)
Heavily revised Feb.2006 based on ideas found in the Python Cookbook (2nd Ed.)

Purpose of This Module
======================

Extends DB-API2 Connection objects with convenience methods
for issuing SQL, accessing results, and reporting.

This module is NOT intended as an object/relational mapper.
Although it does help a programmer move data between Python
and a database with less code, it remains closely tied to SQL operations,
instead of trying to mask SQL behind a mapping interface.

Most importantly, sqlWrap only extends the capabilities of DB-API2;
all original DB-API2 functionality remains.  You should be able to simply
replace your DB-API2 Connection objects with sqlWrap Connection objects
and then continue programming, forgetting that you're using sqlWrap except
when you want to use its features.

Currently supports: Oracle (cx_Oracle), sqlite (pysqlite), postgresql (psycopg)

How To Use This Module
======================

1. Import it: ``import sqlWrap`` or ``from sqlWrap import *``

2. Create a connection to your database:

        cnxn = OraConnection('hr/hr@xe')
        
        cnxn = SqliteConnection('database.sqlite')
        
        cnxn = sqlWrap.PostgresConnection('dbname=pgdb user=postgres password=pwd')
        
3. Perform queries

  a) using the connection's `select` method:

        curs = cnxn.select('employees')
        
    i) Clauses like whereClause can be supplied as Python dictionaries,
       causing automatic use of bind variables.
       
       curs = cnxn.select('employees', whereClause={'department_id':90})

  b) Using standard DB-API2 techniques:
  
        curs = cnxn.cursor()
        
        curs.execute('SELECT * FROM employees')
        
4. Cursors return rows with many conveniences.

        row = curs.next()
        
  a) Dictionary-style access
  
        row['last_name']
        
      Keys are case-insensitive
      
        row['LAST_NAME']
    
  b) Object-style access
  
        row.last_name
        
        row.LAST_NAME
        
5. Cursors have methods for formatted reporting

        - curs.xml()        
        - curs.xmlTransposed()
        - curs.pp()
        - curs.ppTransposed()
        - curs.xhtml()
        - curs.xhtmlTransposed()
        - curs.rst()
        - curs.rstTransposed()
        
6. Convenience methods for insert, update, delete:

        cnxn.insert('regions',{'region_id':5,'region_name':'Antarctica'})
        
        cnxn.update('regions', setClause = {'region_name':'Antarctica'},
        whereClause = {'region_id':5})
        
        cnxn.delete('regions',{'region_id':5})
        
        As always with DB-API2, DML requires explicitly committing transactions.
        
        cnxn.commit()

Incidental benefits
======================

When calling curs.execute, bind variables may be supplied either by the
DB-API2 package's own convention ("Qmark" for pysqlite, 'Pyformat' for
psycopg) or by the Named (cx_Oracle) format:

        curs.execute('SELECT * FROM regions WHERE region_id = :r', {'r': 2})
        
In native cx_Oracle, any "extra" bind variables supplied in the bind variable
dictionary raise errors.  sqlWrap fixes this, letting you use dictionaries like
locals() to pass bind variables.

        curs.execute('SELECT * FROM regions WHERE region_id = :r', {'r': 2, 'a':1})

TODO: trap common errors (like neglecting selectList)
TODO: can submit clauses as objects?
TODO: support executemany
TODO: use password module
TODO: error when curs.fetchone() on empty sqlite result - description is None
"""
__docformat__ = 'restructuredtext'

import string, re, itertools, logging
from xml.sax import saxutils
from datetime import datetime

# Module-wide logger; cursors report every executed statement through it.
logger = logging.getLogger('sqlwrap')
# Note the import itself, handing the values to logging as arguments.
logger.info('importing sqlWrap.py: __doc__ == %s, __name__ == %s', __doc__, __name__)

def comparitor(value):
    """Chooses the SQL comparison operator for a bind value.

       arguments: value
       returns 'LIKE' when value is a str containing the % wildcard,
       '=' for anything else (including non-string values)"""
    hasWildcard = isinstance(value, str) and '%' in value
    if not hasWildcard:
        return '='
    return 'LIKE'
    
class ColumnWriter(object):
    """Base class that turns column:value data into fragments of SQL text.

    Subclasses provide placeholder() and compareStatement() (generators of
    SQL fragments); this class joins their output into comma- or
    AND-separated lists.

    Attributes:
      colValList - list of (column, value) pairs
      colValDict - the mapping those pairs came from
      oldNameMap - maps a renamed bind-variable name back to its real column name
    """
    def __init__(self, colVals):
        """colVals may be a mapping, an object whose __dict__ holds the data
        (such as a Bunch), or None (meaning "no columns")."""
        if colVals is None:
            # None is documented as acceptable for an insert; without this
            # both lookups below would raise AttributeError.
            colVals = {}
        try:
            pairs = colVals.items()
            self.colValDict = colVals
        except AttributeError:
            pairs = colVals.__dict__.items()
            self.colValDict = colVals.__dict__
        # Always hold a real list: it is iterated repeatedly and callers sort it.
        self.colValList = list(pairs)
        self.oldNameMap = {}
    def conditionList(self):
        """Returns the comparison fragments joined with AND, for a WHERE clause."""
        return " AND ".join(self.compareStatement())
    def colNameList(self):
        """Returns the comma-separated column names, undoing any bind-variable renaming."""
        result = []
        for (col, val) in self.colValList:
            result.append(self.oldNameMap.get(col) or col)
        return ", ".join(result)
    def placeholderList(self):
        """Returns the comma-separated placeholders, one per column."""
        return ", ".join(self.placeholder())

class FormatColumnWriter(ColumnWriter):
    """Column writer that marks every bind variable with a positional %s
    and supplies the bind values as a list."""
    def placeholder(self):
        """Generates one %s marker per column.

        >>> print FormatColumnWriter({'col1':'val1'}).placeholder().next()
        %s
        """
        for _pair in self.colValList:
            yield '%s'
    def bindVars(self):
        """Returns the values, in the same order as the columns.

        >>> print FormatColumnWriter({'col1':'val1'}).bindVars()
        ['val1']
        """
        return [val for (col, val) in self.colValList]
    def setStatement(self):
        """Generates "column = marker" fragments for a SET clause.

        >>> print FormatColumnWriter({'col1':'val1'}).setStatement().next()
        col1 = %s
        """
        for (col, val) in self.colValList:
            yield '%s = %s' % (col, '%s')
    def compareStatement(self):
        """Generates "column operator marker" fragments for a WHERE clause.

        >>> print FormatColumnWriter({'col1':'val1'}).compareStatement().next()
        col1 = %s
        >>> print FormatColumnWriter({'col1':'val%'}).compareStatement().next()
        col1 LIKE %s
        """
        for (col, val) in self.colValList:
            operator = comparitor(val)
            yield '%s %s %s' % (col, operator, '%s')
    def reconcileTo(self, other):
        """Nothing to reconcile: positional markers carry no names that could clash."""
        return None
    
class QmarkColumnWriter(ColumnWriter):
    """Column writer that marks every bind variable with a positional ?
    and supplies the bind values as a list."""
    def placeholder(self):
        """Generates one ? marker per column.

        >>> print QmarkColumnWriter({'col1':'val1'}).placeholder().next()
        ?
        """
        for _pair in self.colValList:
            yield '?'
    def bindVars(self):
        """Returns the values, in the same order as the columns.

        >>> print QmarkColumnWriter({'col1':'val1'}).bindVars()
        ['val1']
        """
        return [val for (col, val) in self.colValList]
    def setStatement(self):
        """Generates "column = ?" fragments for a SET clause.

        >>> print QmarkColumnWriter({'col1':'val1'}).setStatement().next()
        col1 = ?
        """
        for (col, val) in self.colValList:
            yield '%s = ?' % col
    def compareStatement(self):
        """Generates "column operator ?" fragments for a WHERE clause.

        >>> print QmarkColumnWriter({'col1':'val1'}).compareStatement().next()
        col1 = ?
        >>> print QmarkColumnWriter({'col1':'val%'}).compareStatement().next()
        col1 LIKE ?
        """
        for (col, val) in self.colValList:
            operator = comparitor(val)
            yield '%s %s ?' % (col, operator)
    def reconcileTo(self, other):
        """Nothing to reconcile: positional markers carry no names that could clash."""
        return None

class DictBasedColumnWriter(ColumnWriter):
    """Column writer whose bind variables are passed as a dict keyed by name.

    Subclasses supply marker(colName), which renders the placeholder text
    for one bind-variable name.
    """
    def placeholder(self):
        """Generates one named marker per column.

        >>> print NamedColumnWriter({'col1':'val1'}).placeholder().next()
        :col1
        """
        for colVal in self.colValList:
            yield self.marker(colVal[0])
    def bindVars(self):
        """Returns the name:value dict itself (not a copy).

        >>> print NamedColumnWriter({'col1':'val1'}).bindVars()
        {'col1': 'val1'}
        """
        return self.colValDict
    def setStatement(self):
        """Generates "column = marker" fragments for a SET clause.

        >>> print NamedColumnWriter({'col1':'val1'}).setStatement().next()
        col1 = :col1
        """
        for (col, val) in self.colValList:
            oldColName = self.oldNameMap.get(col)
            yield '%s = %s' % (oldColName or col, self.marker(col))
    def compareStatement(self):
        """Generates "column operator marker" fragments for a WHERE clause.

        >>> print NamedColumnWriter({'col1':'val1'}).compareStatement().next()
        col1 = :col1
        >>> print NamedColumnWriter({'col1':'val%'}).compareStatement().next()
        col1 LIKE :col1
        """
        for (col, val) in self.colValList:
            oldColName = self.oldNameMap.get(col)
            yield '%s %s %s' % (oldColName or col, comparitor(val), self.marker(col))
    def reconcileTo(self, other):
        """
        Stirs in the col:val pairs from another column handler, without stepping on my own col:val pairs.

        Any of my bind-variable names that `other` also uses is prefixed with
        'x' until it is unique; oldNameMap remembers the real column name.
        The renamed data is stored in a new dict, so the mapping originally
        handed to the constructor is left untouched.

        Raises RuntimeError if this writer has already been reconciled.
        
        >>> ch = NamedColumnWriter({'col1':'val1', 'xcol1':'val2', 'col3':'val3'})
        >>> ch2 = NamedColumnWriter({'col1':'val1a', 'col4':'val4'})
        >>> ch2.reconcileTo(ch)
        >>> cvl = ch2.colValList
        >>> cvl.sort()
        >>> print cvl
        [('col4', 'val4'), ('xxcol1', 'val1a')]
        >>> print ch2.oldNameMap
        {'xxcol1': 'col1'}
        >>> setStmts = [ss for ss in ch2.setStatement()]
        >>> setStmts.sort()
        >>> print setStmts
        ['col1 = :xxcol1', 'col4 = :col4']
        """
        if self.oldNameMap:
            # String exceptions are not valid; raise a real exception class.
            raise RuntimeError("Multiple calls to reconcileTo not supported.")
        reconciled = {}
        for (col, val) in list(self.colValDict.items()):
            newColName = col
            # A new name must be free in `other` and must not collide with any
            # of my own names, original or already reassigned.
            while (newColName in other.colValDict
                   or (newColName != col
                       and (newColName in self.colValDict or newColName in reconciled))):
                newColName = 'x' + newColName
            if newColName != col:
                self.oldNameMap[newColName] = col
            reconciled[newColName] = val
        if self.oldNameMap:
            self.colValDict = reconciled
            self.colValList = list(reconciled.items())

class NamedColumnWriter(DictBasedColumnWriter):
    """Dict-based column writer whose markers are written :name."""
    def marker(self, colName):
        """Returns the placeholder text for colName, e.g. ':col1'."""
        return ':%s' % colName
    
class PyformatColumnWriter(DictBasedColumnWriter):
    """Dict-based column writer whose markers are written %(name)s."""
    def marker(self, colName):
        """Returns the placeholder text for colName, e.g. '%(col1)s'."""
        return '%%(%s)s' % colName

# Characters allowed through purgeUnprintable; built once instead of per call.
_PRINTABLE_CHARS = frozenset(string.printable)

def purgeUnprintable(s):
    """Removes unprintable characters from a string.

    A character is kept only if it appears in string.printable.  Filtering
    character by character (rather than with the two-argument form of
    translate, which only byte strings offer) lets unicode input through too.
    """
    return ''.join([c for c in s if c in _PRINTABLE_CHARS])

class Bunch:
    """A plain attribute holder: data only, no behaviour of its own."""
    def __init__(self, d):
        """Copies every key of mapping d onto the new object as an attribute."""
        attributes = vars(self)
        attributes.update(d)

def noTransform(strng):
    """Identity tag transform: hands back strng exactly as received."""
    unchanged = strng
    return unchanged
def htmlTableTagTransform(strng):
    """Tag transform for HTML tables: 'tr' and 'table' (in any letter case)
    pass through unchanged, every other tag name becomes 'td'."""
    structuralTags = ('tr', 'table')
    if strng.lower() not in structuralTags:
        return 'td'
    return strng

class CaselessDict(dict):
    """dict with case-insensitive keys.

    Keys are stored lower-cased, so every key must offer a lower() method.

    Posted to ASPN Python Cookbook by Jeff Donner - http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66315"""
    def __init__(self, other=None):
        """other may be a mapping or an iterable of (key, value) pairs."""
        if other:
            self.update(other)

    def __getitem__(self, key):
        return dict.__getitem__(self, key.lower())

    def __setitem__(self, key, value):
        dict.__setitem__(self, key.lower(), value)

    def __delitem__(self, key):
        # Without this override, del d['Key'] looked up the un-lowered key
        # and raised KeyError for any key containing an upper-case letter.
        dict.__delitem__(self, key.lower())

    def __contains__(self, key):
        return dict.__contains__(self, key.lower())

    def has_key(self, key):
        # Routed through __contains__ so it does not depend on dict.has_key.
        return dict.__contains__(self, key.lower())

    def get(self, key, def_val=None):
        return dict.get(self, key.lower(), def_val)

    def setdefault(self, key, def_val=None):
        return dict.setdefault(self, key.lower(), def_val)

    def update(self, other=(), **kwargs):
        """Merges a mapping or an iterable of (key, value) pairs, then any
        keyword arguments, lower-casing every key."""
        if hasattr(other, 'items'):
            pairs = other.items()
        else:
            pairs = other
        for k, v in pairs:
            dict.__setitem__(self, k.lower(), v)
        for k, v in kwargs.items():
            dict.__setitem__(self, k.lower(), v)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Returns a new caseless dict holding each key of iterable mapped to value.

        Callable on the class as well as on an instance."""
        d = cls()
        for k in iterable:
            dict.__setitem__(d, k.lower(), value)
        return d

    def pop(self, key, def_val=None):
        """Removes key and returns its value.

        Unlike dict.pop, a missing key never raises KeyError: def_val
        (None when omitted) is returned instead."""
        return dict.pop(self, key.lower(), def_val)
    
class Row(tuple):
    """One result row: a tuple that also allows access by column name.

    Values can be reached by position (row[0]), by case-insensitive column
    name (row['last_name']) or as attributes (row.last_name).

    Attributes:
      cursor - the cursor that produced the row; supplies column names,
               source table name and print format
      dict   - CaselessDict mapping column name to value
    """
    def __new__(cls, t, cursor):
        # A missing row (None) becomes an empty, and therefore false, Row.
        return tuple.__new__(cls, t or ())
    def __init__(self, t, cursor):
        # The tuple contents were fixed in __new__; passing them on to
        # object.__init__ is deprecated/an error, so no arguments go up.
        super(Row, self).__init__()
        self.cursor = cursor
        self.dict = CaselessDict(itertools.izip(cursor.colNames(), self))
    def __getitem__(self, item):
        """Integer or slice access as for a tuple; any index the tuple
        rejects with TypeError is looked up as a column name."""
        try:
            return super(Row, self).__getitem__(item)
        except TypeError:
            return self.dict[item]
    def __getattribute__(self, attrib):
        """Resolves, in order: normal attributes, attributes of the column
        dict, then column names.  An unknown name ends in the KeyError raised
        by the column dict."""
        try:
            return super(Row, self).__getattribute__(attrib)
        except AttributeError:
            try:
                return self.dict.__getattribute__(attrib)
            except AttributeError:
                try:
                    return self.__dict__.__getattribute__(attrib)
                except AttributeError:
                    return self.dict[attrib]
    def colValPairs(self):
        """Returns an iterator of (column name, value) pairs."""
        return itertools.izip(self.cursor.colNames(), self)
    def object(self):
        """Returns an empty object whose attributes correspond to fields in the Row."""
        return Bunch(self.dict)
    def xml(self, tagTransform = noTransform, rowTag = None):
        """Produces a series of XML tags (1 per column) from a DB-API2 cursor and one of its result rows.
        
        tagTransform function, if supplied, is applied to each tag; 
        for instance, str.upper or str.title.
        Row overall is tagged with rowTag argument, or (if omitted) 
        with name of source table.
        NULL (None) values produce an empty element."""
        result = []
        if rowTag is None:
            rowTag = self.cursor.source
        if rowTag:
            result.append("  <%s>" % tagTransform(rowTag))
        for (tag, val) in self.colValPairs():
            tag = tagTransform(tag)
            # Only NULL is blanked; "val or ''" also erased real values
            # such as 0 and False.
            if val is None:
                val = ''
            val = purgeUnprintable(saxutils.escape(str(val)))
            result.append("    <%s>%s</%s>" % (tag, val, tag))
        if rowTag:
            result.append("  </%s>" % tagTransform(rowTag))
        return "\n".join(result)
    def xhtml(self):
        """Returns the row as an HTML <tr> with one <td> per column."""
        return self.xml(tagTransform = htmlTableTagTransform, rowTag = 'tr')
    def pp(self):
        """Returns the row laid out with the cursor's column-width format."""
        return self.cursor.format % self
    def sql(self):
        """Generates an INSERT statement to duplicate the row's data.
        
        Warning: date, time, and datetime fields are not likely to work properly."""
        vals = []
        for val in self:
            # Equality, not identity: "val is ''" depended on string interning.
            if val is None or val == '':
                val = 'NULL'
            elif isinstance(val, datetime):
                val = self.cursor.writeDate(val)
            elif hasattr(val, 'swapcase'):
                # string-like value: quote it, doubling embedded single quotes
                val = "'%s'" % (val.replace("'","''"))
            vals.append(str(val))
        return """INSERT INTO %s (%s) VALUES (%s);""" % (self.cursor.source, ", ".join(self.cursor.colNames()), ", ".join(vals))
        
# Captures the first name after FROM.  Digits and dots are accepted so that
# names such as "table2" or "hr.employees" are not cut short.
findTableNameRe = re.compile(r"\s*SELECT\s+.*\s+FROM\s+([\w$#.]*)", re.IGNORECASE | re.MULTILINE | re.DOTALL)
class CursorMixin(object):
    """Adds Row results and report formatting to a DB-API2 cursor class.

    Meant to be combined with a driver's cursor class; the combined class
    must set `baseClass` to that driver class, because every fetch is
    delegated to it explicitly.

    Attributes set by execute():
      source    - table name found after FROM in the statement, or None
      giveEmpty - the flag passed to execute (stored only)
      lengths   - display width of each column
      format    - %-format string laying one row out in those widths
    """
    def execute(self, statement, params=None, giveEmpty=False):
        """Executes statement through the driver, logging it.

        Returns a generator of Rows over the result set.  Driver errors are
        logged and re-raised unchanged."""
        import sys
        if params is None:
            params = {}
        self.giveEmpty = giveEmpty
        tableNameMatch = findTableNameRe.search(statement)
        self.source = tableNameMatch and tableNameMatch.groups()[0]
        logMsg = 'statement:\n%s\nwith parameters:\n%s' % (statement, params)
        try:
            result = self.baseClass.execute(self, statement, params)
        except Exception:
            # sys.exc_info() works on every Python version, unlike either
            # spelling of "except ... as/comma name".
            logger.error('error:\n%s\nduring execution of %s' % (sys.exc_info()[1], logMsg))
            raise
        logger.debug(logMsg)
        self.setLengths()
        if result is None:
            # DB-API2 leaves execute's return value undefined and some drivers
            # return None; the cursor itself is then the thing to iterate.
            result = self
        def giveRows(rows):
            # every row, not just the first
            for raw in rows:
                yield Row(raw, self)
        return giveRows(result)
    def leftJoinFetchone(self):
        """Fetches one row; if the fetch raises StopIteration, returns a
        stand-in row of None values (the epoch for datetime columns).

        NOTE(review): fetchone() in this class reports exhaustion with an
        empty Row rather than StopIteration, so the stand-in row looks
        reachable only with a driver whose fetchone raises - confirm intent."""
        try:
            return self.fetchone()
        except StopIteration:
            result = [((d[1] == datetime) and datetime.utcfromtimestamp(0)) or None for d in (self.description or ())]
            # Row takes (values, cursor); the old code passed them wrongly.
            return Row(result, self)
    def next(self):
        """Iterator protocol: the next row, as a Row."""
        return Row(self.baseClass.next(self), self)
    def fetchone(self):
        """Returns the next row as a Row; an empty (false) Row when none is left."""
        return Row(self.baseClass.fetchone(self), self)
    def fetchall(self):
        """Returns every remaining row as a list of Rows."""
        return list(self.fetchsome())
    def fetchmany(self, arraysize = 1000):
        """Returns up to arraysize rows as a list of Rows (empty when exhausted)."""
        return [Row(raw, self) for raw in self.baseClass.fetchmany(self, arraysize)]
    def fetchsome(self, arraysize = 1000):
        """A generator simplifying the use of fetchmany.
        From Python Cookbook, 2nd Ed., recipe 19.13"""
        while True:
            results = self.fetchmany(arraysize)
            if not results:
                break
            for result in results:
                yield result
    def colNames(self):
        """Returns the column names of the current result set ([] when the
        cursor has no description, e.g. before a query or after DML)."""
        return [d[0] for d in (self.description or ())]
    def writeDate(self, val):
        """Returns val as a quoted SQL literal."""
        return "'%s'" % str(val)
    def setLengths(self, lengths=None):
        """Computes column widths and the matching row format.

        lengths, when given, supplies the data width per column; otherwise
        the driver's display size is used, falling back to the name length.
        A column is never narrower than its own name."""
        if self.description:
            colNameLengths = [len(cn) for cn in self.colNames()]
            dataLengths = lengths or [d[2] or len(d[0]) for d in self.description]
            self.lengths = [max(pair) for pair in itertools.izip(colNameLengths, dataLengths)]
            self.format = " ".join("%%-%ss" % l for l in self.lengths)
    def shrinkFieldLengths(self):
        """overrides defined field lengths, replacing them with actual maximum data length."""
        data = getattr(self, 'data', None)
        if data:
            lengths = [[len(str(r[col])) for r in data] for col in range(len(data[0]))]
            maxRowLengths = [max(l) for l in lengths]
            self.setLengths(maxRowLengths)
    def markerLine(self, marker = '-'):
        """Returns a rule line: each column's width filled with marker."""
        return self.format % tuple(''.rjust(l, marker) for l in self.lengths)                    
    def pp(self, edgeMark='', headerSep = '-'):
        """Pretty-printer for a whole result set.  """
        self.data = self.fetchall()
        self.shrinkFieldLengths()
        result = []
        result.append(self.format % tuple(self.colNames()))
        result.append(self.markerLine(headerSep))
        result.extend(row.pp() for row in self.data)
        if edgeMark:
            edge = self.markerLine(edgeMark)
            result.insert(0, edge)
            result.append(edge)
        return "\n".join(result)
    def rst(self):
        """ReStructured Text table for a whole result set."""
        return self.pp(edgeMark = '=', headerSep = '-')
    def ppTransposed(self, edgeMark=''):
        """Pretty-printed result set, transposed (column names vertically)"""
        data = self.fetchall()
        colNames = self.colNames()
        if not colNames:
            return ''
        maxColNameLen = max(len(cn) for cn in colNames)
        colNameFormat = "%%%ss" % (maxColNameLen)
        # Each result row becomes one output column, as wide as that row's
        # longest value (the old loop measured whole rows instead).
        ppLengths = [max([len(str(datum)) for datum in row]) for row in data]
        ppFormat = " ".join("%%-%ss" % l for l in ppLengths)
        result = []
        for colNum, colName in enumerate(colNames):
            result.append("%s %s" % (colNameFormat % colName, 
                                     ppFormat % tuple(row[colNum] for row in data)))
        if edgeMark:
            edge = "%s %s" % (colNameFormat % ''.rjust(maxColNameLen, edgeMark), ppFormat % tuple(''.rjust(l, edgeMark) for l in ppLengths))
            result.insert(0, edge)
            result.append(edge)
        return "\n".join(result)
    def rstTransposed(self):
        """Transposed result set with '=' edge lines."""
        return self.ppTransposed(edgeMark = '=')
    def xml(self, rowTag = None, tagTransform = noTransform):
        """Returns every remaining row rendered by Row.xml, one after another."""
        result = [row.xml(tagTransform=tagTransform, rowTag=rowTag) for row in self.fetchsome()]
        return "\n".join(result)
    def xhtmlTh(self):
        """Returns the lines of an HTML header row holding the column names."""
        result = ["  <tr>"]
        for colName in self.colNames():
            result.append("    <th>%s</th>" % (saxutils.escape(colName)))
        result.append("  </tr>")
        return result    
    def xhtml(self):
        """An HTML table of the result set."""
        result = ['<table>']
        result.extend(self.xhtmlTh())
        result.extend(row.xhtml() for row in self.fetchsome())
        result.append('</table>')
        return "\n".join(result)
    def xhtmlTransposed(self):
        """An HTML table of the result set, transposed - occasionally nice when you have few rows and many columns"""
        result = ['<table>']
        data = self.fetchall()
        for colNum, colName in enumerate(self.colNames()):
            # escaped, as Row.xml does, so data cannot inject markup
            colData = ['<td>%s</td>' % saxutils.escape(str(row[colNum])) for row in data]
            result.append('  <tr>\n    <td>%s</td>%s\n  </tr>' % (saxutils.escape(colName), "".join(colData)))
        result.append('</table>')
        return "\n".join(result)
        
def combine(arg1, arg2):
    """Merges arg2 into arg1 in place: arg1.update(arg2) is tried first, and
    arg1.extend(arg2) is used if that raises AttributeError."""
    try:
        arg1.update(arg2)
    except AttributeError:
        pass
    else:
        return
    arg1.extend(arg2)
            
class AnyDbConnection(object):
    """Driver-independent convenience methods for a DB-API2 connection.

    Meant to be combined with a driver's connection class.  The combined
    class must define:
      columnWriter  - ColumnWriter subclass matching the driver's bind style
      bindVarFormat - callable returning an empty bind container (list or dict)
      customCursor  - cursor class, instantiated with the connection
    """
    def cursor(self):
        """Returns a new cursor of the class named by customCursor."""
        return self.customCursor(self)
    def _copyBindVars(self, bindVars):
        """Returns a shallow copy (dict for mappings, list otherwise), so that
        merging more bind variables never changes an object the caller owns."""
        if hasattr(bindVars, 'items'):
            return dict(bindVars)
        return list(bindVars)
    def selectSQL(self, source, selectUs='*', whereClause = {}, userBindVars = None, orderBy = None):
        """Builds a SELECT statement; returns (sql text, bind variables).

        Parameters are those of `select` (selectUs is its selectList)."""
        if isinstance(selectUs, (list, tuple)):
            selectUs = ", ".join(selectUs)
        sql = ['SELECT %s FROM %s' % (selectUs, source)]
        bindVars = userBindVars or self.bindVarFormat()
        if whereClause:
            if isinstance(whereClause, (dict, Bunch)):
                whereColHandler = self.columnWriter(whereClause)
                sql.append(' WHERE %s' % (whereColHandler.conditionList()))
                bindVars = whereColHandler.bindVars()
            else:
                sql.append(' WHERE %s' % (whereClause))
        if orderBy:
            sql.append('ORDER BY %s' % (orderBy))
        return ("\n".join(sql), bindVars)
    
    def select(self, source, selectList='*', whereClause = {}, userBindVars = None, orderBy = None):
        """Builds and executes a SELECT statement; returns a cursor.

        Parameters:
        
        - `source`: Source table name
        - `selectList`: Names of columns to return.  May be a string or a list of strings.
          Defaults to '*' (all columns)
        - `whereClause`: Filtering condition.  May be a dict or a string;
          "col1='a' and col2=22" and {'col1':'a','col2':22} are equivalent,
          but passing a dict automatically causes bind variables to be used.
          If a value contains % (SQL wildcard), LIKE will be used instead of =.
        - `userBindVars`: If whereClause is given as a string with bind variables,
          supply them as a dict.
        - `orderBy`: Columns to order results by.  Given as a string.  Ex: "col1, col2"
        """
        curs = self.cursor()
        (sql, bindVars) = self.selectSQL(source, selectList, whereClause, userBindVars, orderBy)
        curs.execute(sql, bindVars)
        return curs
    def describe(self, source):
        """Returns the column names of `source`, found by running a query
        that selects no rows."""
        curs = self.select(source, "*", "1=0")
        try:
            return curs.colNames()
        finally:
            curs.close()
    def insertSQL(self, target, setClause):
        """Builds an INSERT statement; returns (sql text, bind variables).

        A list (or tuple) setClause holds values only; they are matched, in
        order, to the column names reported by describe(target)."""
        if isinstance(setClause, (list, tuple)):
            setClause = dict(zip(self.describe(target), setClause))
        columnWriter = self.columnWriter(setClause)
        sql = 'INSERT INTO %s (%s) VALUES (%s)' % (target, columnWriter.colNameList(), columnWriter.placeholderList())
        return (sql, columnWriter.bindVars())
    def updateSQL(self, target, setClause, whereClause = {}, userBindVars = None):
        """Builds an UPDATE statement; returns (sql text, bind variables).

        Parameters are those of `update`.  The returned bind variables are a
        fresh container; setClause and userBindVars are not modified."""
        if isinstance(setClause, (dict, Bunch)):
            colHandler = self.columnWriter(setClause)
            setStatements = ", ".join(colHandler.setStatement())
            bindVars = self._copyBindVars(colHandler.bindVars())
            userVarsIncluded = False
        else:
            # SET given as literal SQL text: no column handler exists
            colHandler = None
            setStatements = setClause
            if userBindVars:
                bindVars = self._copyBindVars(userBindVars)
            else:
                bindVars = self.bindVarFormat()
            userVarsIncluded = True
        sql = 'UPDATE %s SET %s' % (target, setStatements)
        if whereClause:
            if isinstance(whereClause, (dict, Bunch)):
                whereColHandler = self.columnWriter(whereClause)
                if colHandler is not None:
                    # keep WHERE bind names from clashing with SET bind names
                    whereColHandler.reconcileTo(colHandler)
                combine(bindVars, whereColHandler.bindVars())
                sql = "%s WHERE %s" % (sql, whereColHandler.conditionList())
            else: 
                sql = "%s WHERE %s" % (sql, whereClause)
                # only add userBindVars to bindVars once, and only if given
                if userBindVars and not userVarsIncluded:
                    combine(bindVars, userBindVars)
        return (sql, bindVars)
    def deleteSQL(self, target, whereColVals = {}):
        """Builds a DELETE statement; returns (sql text, bind variables).

        whereColVals may be a dict (or Bunch) of column:value conditions, or
        a string holding literal WHERE text.  When empty, every row is
        targeted."""
        sql = 'DELETE FROM %s' % (target)
        if whereColVals and hasattr(whereColVals, 'swapcase'):
            # string-like: use as written, with an empty bind container
            return ("%s  WHERE %s" % (sql, whereColVals), self.bindVarFormat())
        # "or {}" gives each call its own dict instead of the shared default
        whereColHandler = self.columnWriter(whereColVals or {})
        if whereColVals:
            sql = "%s  WHERE %s" % (sql, whereColHandler.conditionList())
        return (sql, whereColHandler.bindVars())
    def dml(self, sql, bindVars):
        """Executes one statement on a throwaway cursor and returns its
        rowcount.  The cursor is closed even if execution fails."""
        curs = self.cursor()
        try:
            curs.execute(sql, bindVars)
            return curs.rowcount
        finally:
            curs.close()
    def insert(self, target, setClause):
        """Inserts to a table.  Returns number of rows inserted.

        Parameters:
        
        - `target`: Table to insert into
        - `setClause`: A dict containing column:value data to insert.
          Columns not included will be set to NULL.
        """
        (sql, bindVars) = self.insertSQL(target, setClause) 
        return self.dml(sql, bindVars)
    def update(self, target, setClause, whereClause={}, userBindVars = None):
        """Updates a table.  Returns number of rows updated.

        Parameters:
        
        - `target`: Table to update.
        - `setClause`: A dict containing column:value data to set.
        - `whereClause`: Filtering condition.  May be a dict or a string;
          "col1='a' and col2=22" and {'col1':'a','col2':22} are equivalent,
          but passing a dict automatically causes bind variables to be used.
          If a value contains % (SQL wildcard), LIKE will be used instead of =.
        - `userBindVars`: If whereClause is given as a string with bind variables,
          supply them as a dict.
        """
        (sql, bindVars) = self.updateSQL(target, setClause, whereClause, userBindVars)
        return self.dml(sql, bindVars)
    def delete(self, target, whereClause={}):
        """Deletes from a table.  Returns number of rows deleted.

        Parameters:
        
        - `target`: Table to delete from.
        - `whereClause`: Filtering condition.  May be a dict or a string;
          "col1='a' and col2=22" and {'col1':'a','col2':22} are equivalent,
          but passing a dict automatically causes bind variables to be used.
          If a value contains % (SQL wildcard), LIKE will be used instead of =.
        """
        (sql, bindVars)  = self.deleteSQL(target, whereClause)
        return self.dml(sql, bindVars)
    
try:
    try:
        from pysqlite2 import dbapi2 as sqlite
    except ImportError:
        # the same driver ships in the standard library as sqlite3
        import sqlite3 as sqlite
except ImportError:
    # no sqlite driver available: the sqlite classes are simply not defined
    pass
else:
    class SqliteCustomCursor(CursorMixin, sqlite.Cursor):
        """sqlite cursor returning Rows and offering the report methods of CursorMixin."""
        baseClass = sqlite.Cursor
    class SqliteConnection(AnyDbConnection, sqlite.Connection):   
        """Extends the sqlite DB-API2 Connection object; see http://www.sqlite.org/
        adds convenience methods: select, insert, update, delete, describe"""        
        columnWriter = QmarkColumnWriter 
        bindVarFormat = list
        customCursor = SqliteCustomCursor

try:
    import psycopg2
    import psycopg2.extensions
except ImportError:
    # psycopg2 not installed: the Postgres classes are simply not defined
    pass
else:
    # psycopg2.extensions is the public home of the connection and cursor
    # classes; psycopg2._psycopg is a private implementation module.
    class PostgresCursor(CursorMixin, psycopg2.extensions.cursor):
        """psycopg2 cursor returning Rows and offering the report methods of CursorMixin."""
        baseClass = psycopg2.extensions.cursor
    class PostgresConnection(AnyDbConnection, psycopg2.extensions.connection):
        """Extends psycopg connection object; see http://www.zope.org/Members/fog/psycopg
        adds convenience methods: select, insert, update, delete, describe"""                
        columnWriter = PyformatColumnWriter 
        bindVarFormat = dict
        customCursor = PostgresCursor         

"""
try:
    import MySQLdb
    class MySQLCursor(CursorMixin, MySQLdb.cursors.Cursor):
        baseClass = MySQLdb.cursors.Cursor
    class MySQLConnection(AnyDbConnection, MySQLdb.connections.Connection):
        Extends MySQLdb connection object
        adds convenience methods: select, insert, update, delete, genericSelect       
        columnWriter = FormatColumnWriter 
        bindVarFormat = dict
        customCursor = MySQLCursor         
except ImportError:
    pass        
"""            
            
try:
    import cx_Oracle
except ImportError:
    # cx_Oracle not installed: the Oracle classes are simply not defined
    pass
else:
    # finds :name and :{name} bind variables in a statement
    findBindVarRe = re.compile(r':{?\w+}?\b')
    class OraCustomCursor(CursorMixin, cx_Oracle.Cursor):
        """cx_Oracle cursor returning Rows; tolerates unused bind variables."""
        baseClass = cx_Oracle.Cursor
        def execute(self, statement, params=None, giveEmpty=False):
            """Executes statement, first dropping from a dict of bind variables
            every entry the statement does not mention (so a dict such as
            locals() can be passed).  Positional (list/tuple) parameters go
            through untouched.  The caller's dict is not modified.

            Returns whatever CursorMixin.execute returns."""
            if params is None:
                params = {}
            if isinstance(params, (list, tuple)):
                bindVars = params
            else:
                bindVars = params.copy()
                bindVarsUsed = [bv.strip(':{}') for bv in findBindVarRe.findall(statement)]
                for param in list(params.keys()):
                    if param not in bindVarsUsed:
                        bindVars.pop(param)
            return super(OraCustomCursor, self).execute(statement, bindVars, giveEmpty)
        def writeDate(self, val):
            """Returns val as an Oracle TO_DATE(...) expression."""
            return "TO_DATE('%s', 'YYYY-MM-DD HH24:MI:SS')" % str(val)
    class OraConnection(AnyDbConnection, cx_Oracle.Connection):
        """Extends cx_Oracle Connection object; see http://www.computronix.com/utilities.shtml#Oracle
        adds convenience methods: select, insert, update, delete, describe, metadata"""
        columnWriter = NamedColumnWriter
        bindVarFormat = dict
        customCursor = OraCustomCursor
            
        def metadata(self, objType, schema, objName = None):
            """Returns the text produced by the dbms_metadata function
            appropriate to objType, or '' when Oracle reports ORA-31608
            (no objects meet the criteria).

            Raises KeyError for an objType not in the table below; other
            database errors propagate."""
            import sys
            dep = 'get_dependent_ddl(:objType,:objName,:schema)'
            granted = 'get_granted_ddl(:objType, :schema)'
            sch = 'get_ddl(:objType,:objName,:schema)'
            metadataProgramMap = {
                'OBJECT_GRANT': granted, 'SYSTEM_GRANT': granted, 'ROLE_GRANT': granted, 'DEFAULT_ROLE': granted, 'PROXY': granted, 'TABLESPACE_QUOTA': granted,
                'COMMENT': dep, 'CONSTRAINT': dep, 'INDEX': dep, 'INDEX_STATISTICS': dep, 'MATERIALIZED_VIEW_LOG': dep, 'REF_CONSTRAINT': dep, 'TABLE_DATA': dep, 'TRIGGER': dep,
                'SYNONYM': sch, 'TABLE': sch, 'CLUSTER': sch, 'DB_LINK': sch, 'DIMENSION': sch, 'FUNCTION': sch, 'INDEXTYPE': sch,
                'JAVA_SOURCE': sch, 'LIBRARY': sch, 'MATERIALIZED_VIEW': sch, 'OPERATOR': sch, 'PACKAGE': sch, 'PACKAGE_SPEC': sch,
                'PACKAGE_BODY': sch, 'REFRESH_GROUP': sch, 'ROLE': sch, 'SEQUENCE': sch, 'TYPE': sch, 'TYPE_SPEC': sch, 'TYPE_BODY': sch, 'VIEW': sch, 'XMLSCHEMA': sch}
            metadataProgram = metadataProgramMap[objType]
            # Name the binds explicitly instead of handing over locals();
            # execute() drops whichever of them the chosen call does not use.
            bindVars = {'objType': objType, 'schema': schema, 'objName': objName}
            curs = self.cursor()
            try:
                try:
                    curs.execute('SELECT dbms_metadata.%s FROM dual' % metadataProgram, bindVars)
                    raw = curs.fetchone()
                except cx_Oracle.DatabaseError:
                    if 'ORA-31608' in str(sys.exc_info()[1]):  #no objects meet the criteria
                        return ''
                    raise
                return str(raw[0]).strip()
            finally:
                curs.close()

def _test():
    """Runs this module's doctests, then the examples in sqlWrap_examples.txt."""
    import doctest
    import unittest
    import sqlWrap
    doctest.testmod(sqlWrap)
    examples = doctest.DocFileSuite('sqlWrap_examples.txt')
    runner = unittest.TextTestRunner()
    runner.run(examples)
    
if __name__ == "__main__":
    # Running the module as a script currently does nothing: the string below
    # is an inert expression statement and the _test() call is commented out.
    "Silent return implies that all unit tests succeeded.  Use -v to see details."
    #_test()