So many times I’ve been faced with a project to perform ETL (Extract, Transform, Load) between two or more databases. And every time it grates at my soul, because there’s just no good way to handle it. On the one hand, you can use some third-party tool that probably comes with a hefty license fee and arrives loaded with wizards and its own custom syntax for handling external logic (because ETL would be too easy if it were just a one-to-one transfer).
OR, you can go with the other, more popular method: scripting the entire painstaking process by hand. You’ve probably had to deal with this, too. You need to get Table A from Database 1 into Tables B and C in Database 2. But wait: one of Table C’s columns is non-nullable, while the matching column in Table A allows null values. So for this one case you have to write additional logic.
As it turns out, you end up having to write additional logic for most parts of the system, and you have to maintain it. Time passes, schemas change, and the monkey-job of spaghetti code that handles the process has to change with them.
Faced with this task yet again, I decided to try something different.
Now, I want to make it clear that I am not in any way endorsing the ORM philosophy. I will neither extol its virtues nor berate its flaws in this article. I will also not claim that this code is pretty, or should ever be used by another human being. Maybe a chimp. It’s merely a prototype that made a dull job slightly more interesting. I’m posting it here because I happen to think it was a neat way to solve the problem, and there are many places I can re-use this code with minimal changes and effort.
How it’s laid out:
def setDefaults(result):
    for row in result:
        row['campaign_name'] = 'Default Campaign'
        #str.find() returns 0 (falsy) for a match at position 0, so test the prefix directly
        row['sec_mode'] = 3 if row['url'].startswith('https://') else 2
    return result
etl(
    self.oracle
).table(
    name='wc_active_listing', alias='wca'
).join(
    name='campaign', alias='c',
    predicate='wca.campaign_id=c.campaign_id'
).fetchAll(
    'status = 100'
).callback(setDefaults).into.database(
    self.mssql
).table(
    name='wc_active_listing', alias='bsa', primary_key=['listing_id','category_id'],
    mapping = {
        'listing_id': 'listing_id',
        'category_id': 'category_id',
        'bid': 'bid',
        [....]
    }
).commit()
It’s inspired by LINQ, but is much, much simpler. Currently it supports multiple callbacks, inner joins, and outputs to another database connection.
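To make the callback idea concrete, here is a minimal standalone sketch of running several callbacks in order over a list of row dicts. It does not use the etl class at all: `apply_callbacks`, `fill_missing_campaign` and `tag_security_mode` are names I made up for illustration, and I am assuming that a sec_mode of 3 means HTTPS and 2 means plain HTTP.

```python
def fill_missing_campaign(rows):
    """Replace a NULL (None) campaign name with a default value."""
    for row in rows:
        if row.get('campaign_name') is None:
            row['campaign_name'] = 'Default Campaign'
    return rows


def tag_security_mode(rows):
    """Assumption: sec_mode 3 means HTTPS, 2 means plain HTTP."""
    for row in rows:
        row['sec_mode'] = 3 if row['url'].startswith('https://') else 2
    return rows


def apply_callbacks(rows, *callbacks):
    """Run each callback in order, feeding it the previous result."""
    for callback in callbacks:
        rows = callback(rows)
    return rows


rows = [
    {'url': 'https://example.com', 'campaign_name': None},
    {'url': 'http://example.org', 'campaign_name': 'Spring Sale'},
]
cleaned = apply_callbacks(rows, fill_missing_campaign, tag_security_mode)
```

Each callback takes the whole result set and returns it, which is the same contract setDefaults follows, so a nullable source column can be patched up before it ever reaches a non-nullable destination column.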
It’s actually more complex than it needs to be, but I encountered many issues with the pymssql package, so a few workarounds had to be made. You’ll laugh or cry when you see them.
I’ll also note that because our server runs SQL Server 2005, we didn’t have access to MERGE either (it only arrived in SQL Server 2008), so the update-or-insert has to be done by hand.
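Without MERGE, an upsert has to be emulated: look the row up by its key, then UPDATE it if it exists or INSERT it if it doesn’t. Here is a minimal sketch of that pattern, using Python’s built-in sqlite3 in place of SQL Server so it runs anywhere. The `upsert` helper and the `listing` table are hypothetical, and note that sqlite3 uses `?` placeholders where pymssql uses `%s`.

```python
import sqlite3


def upsert(conn, table, row, primary_key):
    """Emulate MERGE: SELECT by key, then UPDATE or INSERT.

    Table and column names are interpolated into the SQL, so they
    must come from trusted code, never from user input.
    """
    key_sql = ' AND '.join('%s = ?' % key for key in primary_key)
    key_values = [row[key] for key in primary_key]
    cursor = conn.cursor()
    cursor.execute('SELECT 1 FROM %s WHERE %s' % (table, key_sql), key_values)
    if cursor.fetchone() is not None:
        columns = [col for col in row if col not in primary_key]
        if columns:  # nothing to update when the row is only its key
            set_sql = ', '.join('%s = ?' % col for col in columns)
            cursor.execute(
                'UPDATE %s SET %s WHERE %s' % (table, set_sql, key_sql),
                [row[col] for col in columns] + key_values)
        return 'update'
    columns = list(row)
    placeholders = ', '.join('?' for _ in columns)
    cursor.execute(
        'INSERT INTO %s (%s) VALUES (%s)'
        % (table, ', '.join(columns), placeholders),
        [row[col] for col in columns])
    return 'insert'


conn = sqlite3.connect(':memory:')
conn.execute(
    'CREATE TABLE listing (listing_id INTEGER, category_id INTEGER, bid REAL)')
key = ['listing_id', 'category_id']
first = upsert(conn, 'listing',
               {'listing_id': 1, 'category_id': 7, 'bid': 0.25}, key)
second = upsert(conn, 'listing',
                {'listing_id': 1, 'category_id': 7, 'bid': 0.40}, key)
conn.commit()
stored = conn.execute('SELECT bid FROM listing').fetchall()
```

The SELECT and the write that follows are two separate statements, so this is only safe when nothing else writes to the destination table at the same time, which is usually true for a batch ETL run.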
The full code, in all its hideous glory:
class etl(object):
    #Class-level handle used by the static fetchAll below; each new
    #etl() replaces it with its own connection.
    _instance = None

    def __init__(self, connection, result = None):
        etl._instance = connection
        self._instance = connection
        self._result = result

    def table(self, name, alias = None, result = None, mapping = None, columns = None, primary_key = None):
        #With a pending result and a mapping this is the destination
        #side of the transfer; otherwise it starts a SELECT.
        if self._result is not None and mapping is not None:
            return insert(
                conn=self._instance,
                name=name,
                alias=alias,
                result=self._result,
                mapping=mapping,
                primary_key=primary_key)
        return select(
            conn=self._instance,
            name=name,
            alias=alias,
            mapping=mapping,
            columns=columns)

    def query(self, sql):
        return query(
            conn=self._instance,
            sql=sql)

    @staticmethod
    def fetchAll(sql, params = None):
        """Returns a list of dicts with columns-as-keys from a query."""
        cursor = etl._instance.cursor()
        cursor.execute(sql, params if params is not None else {})
        rows = cursor.fetchall()
        desc = cursor.description
        result = list()
        for data in rows:
            if data is None:
                return None
            #Get the field descriptors for this row
            dic = {}
            #zip pairs the nTh column descriptor with the nTh value;
            #item 0 of a descriptor is the column name.
            for (name, value) in zip(desc, data):
                dic[name[0].lower()] = value
            result.append(dic)
        print("Found %d rows" % len(result))
        print("---------------------------------------")
        return result

class select(object):
    def __init__(self, conn, name, alias, mapping, columns = None):
        column_list = '*'
        if isinstance(columns, list):
            column_list = ', '.join(columns)
        alias_sql = ' ' + alias + ' ' if alias is not None else ' '
        self.sql = "SELECT " + column_list + " FROM " + name + alias_sql
        self.database = conn
        self.name = name
        self.alias = alias

    def join(self, name, predicate, alias = None):
        alias_sql = ' ' + alias + ' ' if alias is not None else ' '
        self.sql += 'JOIN ' + name + alias_sql + 'ON (' + predicate + ') '
        return self

    def fetchAll(self, predicate = None, params = None):
        predicate_sql = ' WHERE ' + predicate if predicate is not None else ' '
        self.sql += predicate_sql
        print("----------------------------------------")
        print("SQL: " + self.sql)
        print("----------------------------------------")
        return result(etl.fetchAll(self.sql, params))

class query(object):
    def __init__(self, conn, sql):
        self.sql = sql
        self.database = conn

    def fetchAll(self, params = None):
        print(self.sql)
        return result(etl.fetchAll(self.sql, params))

class insert(object):
    def __init__(self, conn, name, alias, result, mapping, primary_key):
        """Blame the crappy freetds support for sql server on this mess."""
        self.database = conn
        self.cursor = self.database.cursor()
        primary_key_sql = ''
        if isinstance(primary_key, str):
            primary_key = [primary_key]
        for key in primary_key:
            primary_key_sql += key + ' = %s AND '
        primary_key_sql += '1=1' #I'm lazy
        update_sql = "UPDATE " + name + " SET %s WHERE "
        insert_sql = "INSERT INTO " + name + " (%s) VALUES (%s)"
        select_sql = "SELECT 1 FROM " + name + " WHERE "
        i = 0
        for row in result:
            i = i+1
            #Rename source keys to destination columns, dropping unmapped ones
            nMap = dict()
            for (key, val) in row.items():
                if key in mapping:
                    nMap[ mapping[key] ] = str(val)
            values = list(nMap.values())
            key_values = [nMap[key] for key in primary_key]
            #Because pymssql sucks and rowcount doesn't work, we have to do a select
            self.cursor.execute(select_sql + primary_key_sql, tuple(key_values))
            row_exists = self.cursor.fetchone() is not None
            print(select_sql + primary_key_sql % tuple(key_values))
            if row_exists:
                #The row is already there, so update it
                col_map_list = [x + '=%s' for x in nMap.keys()]
                col_map = ', '.join(col_map_list)
                print("-----------------------------------------")
                print("Row #%d" % i)
                print("SQL: " + ((update_sql % col_map) + primary_key_sql) % tuple(values + key_values))
                self.cursor.execute(
                    ((update_sql % col_map) + primary_key_sql),
                    tuple(values + key_values)
                )
                print("Affected: " + str(self.cursor._source.rows_affected))
            else:
                #Do the insert if the row doesn't exist yet
                col_map = ("%s, " * len(nMap)).rstrip(', ')
                print("SQL: " + (insert_sql % (", ".join(nMap.keys()), col_map)) % tuple(values))
                try:
                    self.cursor.execute(
                        insert_sql % (", ".join(nMap.keys()), col_map),
                        tuple(values)
                    )
                except Exception as e:
                    print(e)
            print("-----------------------------------------")

    def commit(self):
        return self.database.commit()

    def rollback(self):
        return self.database.rollback()

class result(object):
    def __init__(self, res):
        self._result = res
        self.into = into(res)

    def callback(self, callback):
        #Keep the transformed rows so a second callback sees them too
        self._result = callback(self._result)
        self.into = into(self._result)
        return self

class into(object):
    def __init__(self, result):
        self._result = result

    def database(self, connection):
        return etl(
            connection=connection,
            result=self._result)
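
Two ideas carry most of the weight in the listing: turning cursor rows into dicts keyed by lower-cased column name, and renaming those keys through a mapping before writing. Both can be tried in isolation. Here is a small sketch against an in-memory sqlite3 database; `rows_as_dicts` and `remap` are hypothetical helpers written for this example, not part of the class above.

```python
import sqlite3


def rows_as_dicts(cursor):
    """Build one dict per row, keyed by lower-cased column name.

    cursor.description holds one 7-item tuple per column (DB-API 2.0);
    item 0 of each tuple is the column name.
    """
    names = [column[0].lower() for column in cursor.description]
    return [dict(zip(names, row)) for row in cursor.fetchall()]


def remap(row, mapping):
    """Keep only mapped source keys, renamed to destination columns."""
    return dict(
        (mapping[key], value) for key, value in row.items() if key in mapping)


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE src (LISTING_ID INTEGER, BID REAL, NOTE TEXT)')
conn.execute("INSERT INTO src VALUES (1, 0.25, 'ignored')")
cursor = conn.execute('SELECT * FROM src')
source_rows = rows_as_dicts(cursor)
dest_rows = [
    remap(row, {'listing_id': 'id', 'bid': 'bid_amount'})
    for row in source_rows
]
```

Lower-casing the names is what lets a single mapping work when one database reports column names in upper case and the other in lower case.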