The poor man’s ETL (Python)

So many times I’ve been faced with a project to perform ETL (Extract, Transform, Load) between two or more databases. And every time it grates at my soul, because there’s just no good way to handle it. On the one hand, you can use some third-party tool that probably comes with a hefty license fee and comes loaded with wizards and its own custom syntax for handling external logic (because ETL would be too easy if it was just a one-to-one transfer).

OR, you can go with the other, more popular method: scripting the entire painstaking process by hand. You’ve probably had to deal with this, too. You need to get Table A from Database 1 into Tables B and C in Database 2. But wait: one of Table C’s columns is non-nullable, while the matching column in Table A allows null values. So for this one case you have to write additional logic.

As it turns out, you end up having to write additional logic for most parts of the system, and you have to maintain it. Time passes, schemas change, and so will the monkey-job of spaghetti code that handles the process.

Faced with this task yet again, I decided to try something different.

Now, I want to make it clear that I am not in any way endorsing the ORM philosophy. I will neither extol its virtues nor berate its flaws in this article. I will also not claim that this code is pretty, or should ever be used by another human being. Maybe a chimp. It’s merely a prototype that made a dull job slightly more interesting. I’m posting it here because I happen to think it was a neat way to solve the problem, and there are many places where I can re-use this code with minimal changes and effort.

How it’s laid out:

def setDefaults(result):
    for row in result:
        row['campaign_name'] = 'Default Campaign'
        row['sec_mode'] = 3 if row['url'].startswith('https://') else 2
    return result

etl(
    self.oracle
).table(
    name='wc_active_listing', alias='wca'
).join(
    name='campaign', alias='c',
    predicate='wca.campaign_id=c.campaign_id'
).fetchAll(
    'status = 100'
).callback(setDefaults).into.database(
    self.mssql
).table(
    name='wc_active_listing', alias='bsa', primary_key=['listing_id','category_id'],
    mapping = {
        'listing_id': 'listing_id',
        'category_id': 'category_id',
        'bid': 'bid',
        [....]
    }
).commit()
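
As a rough illustration of the callback contract used above (this is not part of the prototype): a callback receives the whole list of row dictionaries and returns the list that should be loaded. The sketch below chains two callbacks by hand. The names `fill_missing_bid`, `set_campaign_name` and `apply_callbacks`, along with the sample rows, are made up for the example, and it assumes that replacing a NULL bid with 0 is acceptable for the target column.

```python
def fill_missing_bid(rows):
    """Hypothetical callback: give NULL bids a default so a NOT NULL column accepts them."""
    for row in rows:
        if row.get('bid') is None:
            row['bid'] = 0  # assumed default, pick whatever the target schema expects
    return rows


def set_campaign_name(rows):
    """Hypothetical callback: add a column the source table does not have."""
    for row in rows:
        row['campaign_name'] = 'Default Campaign'
    return rows


def apply_callbacks(rows, callbacks):
    """Run each callback in order, feeding it the previous callback's output."""
    for callback in callbacks:
        rows = callback(rows)
    return rows


# Sample rows shaped like the list of dicts a fetch would return
rows = [
    {'listing_id': 1, 'category_id': 7, 'bid': None},
    {'listing_id': 2, 'category_id': 7, 'bid': 5},
]
rows = apply_callbacks(rows, [fill_missing_bid, set_campaign_name])
```

Feeding each callback the previous one's return value means a callback can also filter rows or return a brand new list, not just mutate the rows in place.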

It’s inspired by LINQ, but is much, much simpler. Currently it supports multiple callbacks, inner joins, and outputs to another database connection.
It’s actually more complex than it needs to be, but I encountered many issues with the pymssql package, so a few work-arounds had to be made. You’ll laugh or cry when you see them.
I’ll also note that due to our server being SQL Server 2005, we didn’t have access to MERGE either.
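
Without MERGE, the usual fallback is "check whether the key exists, then UPDATE or INSERT". Here is a minimal sketch of that pattern, using the standard library's sqlite3 as a stand-in database so it runs anywhere. The `upsert` helper and the `listing` table are made up for the example, and sqlite3 uses `?` placeholders where pymssql uses `%s`.

```python
import sqlite3


def upsert(cursor, table, row, primary_key):
    """Update the row if its key already exists, otherwise insert it.

    Assumes table and column names are trusted (they are concatenated into
    the SQL) and that the row has at least one non-key column.
    """
    key_clause = ' AND '.join(k + ' = ?' for k in primary_key)
    key_values = [row[k] for k in primary_key]

    # Look for an existing row with the same key
    cursor.execute('SELECT 1 FROM ' + table + ' WHERE ' + key_clause, key_values)
    if cursor.fetchone() is not None:
        columns = [c for c in row if c not in primary_key]
        set_clause = ', '.join(c + ' = ?' for c in columns)
        cursor.execute(
            'UPDATE ' + table + ' SET ' + set_clause + ' WHERE ' + key_clause,
            [row[c] for c in columns] + key_values)
        return 'update'

    columns = list(row)
    placeholders = ', '.join('?' for _ in columns)
    cursor.execute(
        'INSERT INTO ' + table + ' (' + ', '.join(columns) + ') VALUES (' + placeholders + ')',
        [row[c] for c in columns])
    return 'insert'


conn = sqlite3.connect(':memory:')
cur = conn.cursor()
cur.execute('CREATE TABLE listing (listing_id INTEGER, category_id INTEGER, bid REAL, '
            'PRIMARY KEY (listing_id, category_id))')

key = ['listing_id', 'category_id']
first = upsert(cur, 'listing', {'listing_id': 1, 'category_id': 7, 'bid': 0.5}, key)
second = upsert(cur, 'listing', {'listing_id': 1, 'category_id': 7, 'bid': 0.75}, key)
conn.commit()

stored = cur.execute('SELECT listing_id, category_id, bid FROM listing').fetchall()
```

The SELECT and the following write are two separate statements, so this is only safe when nothing else writes to the table at the same time, which is typically the case for a batch ETL job.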

The full code, in all its hideous glory:

class etl(object):

    _instance = None

    def __init__(self, connection, result = None):
        etl._instance = connection
        self._instance = connection
        self._result = result

    def table(self, name, alias = None, result = None, mapping = None, columns = None, primary_key = None):
        if self._result is not None and mapping is not None:
            return insert(
                conn=self._instance,
                name=name,
                alias=alias,
                result=self._result,
                mapping=mapping,
                primary_key=primary_key)

        return select(
            conn=self._instance,
            name=name,
            alias=alias,
            mapping=mapping,
            columns=columns)

    def query(self, sql):
        return query(
            conn=self._instance,
            sql=sql)

    @staticmethod
    def fetchAll(sql, params = None):
        """Returns a list of dicts with columns-as-keys from a query."""
        cursor = etl._instance.cursor()
        cursor.execute(sql, params if params is not None else {})

        rows = cursor.fetchall()
        desc = cursor.description
        result = list()
        for data in rows:
            #Build one dict per row, keyed by lower-cased column name
            dic = {}

            #zip pairs the nth column descriptor with the nth value in the row;
            #the first item of each descriptor is the column name.
            for (name, value) in zip(desc, data):
                dic[name[0].lower()] = value
            result.append(dic)

        print("Found %d rows" % len(result))
        print("---------------------------------------")

        return result

class select(object):

    def __init__(self, conn, name, alias, mapping, columns = None):
        column_list = '*'

        if isinstance(columns, list):
            column_list = ', '.join(columns)

        alias_sql = ' ' + alias + ' ' if alias is not None else ' '
        self.sql = "SELECT " + column_list + " FROM " + name + alias_sql
        self.database = conn
        self.name = name
        self.alias = alias

    def join(self, name, predicate, alias = None):

        alias_sql = ' ' + alias + ' ' if alias is not None else ' '
        self.sql += 'JOIN ' + name + alias_sql + 'ON (' + predicate + ') '

        return self

    def fetchAll(self, predicate = None, params = None):
        predicate_sql = ' WHERE ' + predicate if predicate is not None else ' '
        self.sql += predicate_sql
        print("----------------------------------------")
        print("SQL: " + self.sql)
        print("----------------------------------------")

        return result(etl.fetchAll(self.sql, params))

class query(object):

    def __init__(self, conn, sql):
        self.sql = sql
        self.database = conn

    def fetchAll(self, params = None):
        print(self.sql)
        return result(etl.fetchAll(self.sql, params))

class insert(object):
    def __init__(self, conn, name, alias, result, mapping, primary_key):
        """Blame the crappy freetds support for sql server on this mess."""
        self.database = conn
        self.cursor = self.database.cursor()

        primary_key_sql = ''
        if isinstance(primary_key, str):
            primary_key = [primary_key]

        for key in primary_key:
            primary_key_sql += key + ' = %s AND '

        primary_key_sql += '1=1' #I'm lazy

        update_sql = "UPDATE " + name + " SET %s WHERE "
        insert_sql = "INSERT INTO " + name + " (%s) VALUES (%s)"
        select_sql = "SELECT 1 FROM " + name + " WHERE "

        i = 0
        for row in result:
            i = i+1
            nMap = dict()
            for (key, val) in row.items():
                if key in mapping:
                    nMap[ mapping[key] ] = str(val)

            #Because pymssql sucks and rowcount doesn't work, we have to do a select
            key_values = [nMap[key] for key in primary_key]
            self.cursor.execute(select_sql + primary_key_sql, tuple(key_values))
            row_exists = self.cursor.fetchone() is not None

            print(select_sql + primary_key_sql % tuple(key_values))

            if row_exists:
                #The row already exists, so update it in place
                col_map_list = [x + '=%s' for x in nMap.keys()]
                col_map = ', '.join(col_map_list)
                update_params = tuple(list(nMap.values()) + key_values)

                print("-----------------------------------------")
                print("Row #%d" % i)
                print("SQL: " + ((update_sql % col_map) + primary_key_sql) % update_params)

                self.cursor.execute(
                    ((update_sql % col_map) + primary_key_sql),
                    update_params
                )
                print("Affected: " + str(self.cursor._source.rows_affected))

            else:
                #No matching row was found, so insert a new one
                col_map = ("%s, " * len(nMap)).rstrip(', ')
                print("SQL: " + (insert_sql % (", ".join(nMap.keys()), col_map)) % tuple(nMap.values()))

                try:
                    self.cursor.execute(
                        insert_sql % (", ".join(nMap.keys()), col_map),
                        tuple(nMap.values())
                    )
                except Exception as e:
                    print(e)

                print("-----------------------------------------")

    def commit(self):
        return self.database.commit()

    def rollback(self):
        return self.database.rollback()

class result(object):
    def __init__(self, res):
        self._result = res
        self.into = into(res)

    def callback(self, callback):
        #Keep the callback's output so that chained callbacks build on each other
        self._result = callback(self._result)
        self.into = into(self._result)
        return self

class into(object):
    def __init__(self, result):
        self._result = result

    def database(self, connection):
        return etl(
            connection=connection,
            result=self._result)
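
For comparison, the rows-as-dicts step in etl.fetchAll can be written more compactly with `dict(zip(...))`. This is only a sketch run against an in-memory sqlite3 database. The `fetch_all_as_dicts` name and the `campaign` sample table are made up for the example, and the connection's cursor is passed in explicitly instead of being read from a class attribute.

```python
import sqlite3


def fetch_all_as_dicts(cursor, sql, params=()):
    """Run a query and return a list of dicts keyed by lower-cased column name."""
    cursor.execute(sql, params)
    # Each entry of cursor.description is a sequence whose first item is the column name
    names = [column[0].lower() for column in cursor.description]
    return [dict(zip(names, row)) for row in cursor.fetchall()]


conn = sqlite3.connect(':memory:')
cur = conn.cursor()
cur.execute('CREATE TABLE campaign (CAMPAIGN_ID INTEGER, Name TEXT)')
cur.execute("INSERT INTO campaign VALUES (1, 'spring')")

found = fetch_all_as_dicts(
    cur, 'SELECT CAMPAIGN_ID, Name FROM campaign WHERE CAMPAIGN_ID = ?', (1,))
```

Lower-casing the names is what lets one mapping dictionary work against databases that report column names in different cases.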
