#!/usr/bin/python
"""rundotsql.py run .sql files on the command line
Peter Bengtsson, Fry-IT Ltd, <peter@fry-it.com>, Oct 2004

USAGE: ./rundotsql.py OPTIONS SQLFILE

Options:
    -t, --test         Rolls back the SQL execution (dry run)
    -U <username>      Username for connecting with
    <database>         Database to try it on
    -v, --verbose      More verbose info such as the SQL statement
    --version          Prints version and exits
    
    The database connection parameters are stored in ~/.rundotsql.conf
    after the first time you have used it.

Examples:
    $ ./rundotsql.py -U peterbe quietdays SQLUpdateWallet.sql
    $ ./rundotsql.py --test SQLUpdateWallet.sql
"""

#
# CHANGES
#
# 0.6          Added support for <dtml-comment>
# 0.5          tmp files not saved as .sql
# 0.4          better checking on mismatching parameters
#

# Version string printed by --version; keep in step with the CHANGES list.
__version__='0.6'

import sys,os,re,string,time,random

# psql connection arguments from the previous run are remembered in this
# file in the user's home directory.
conffile = '.rundotsql.conf'
conffile_home = os.path.join(os.path.expanduser('~'), conffile)

# Matchers for the DTML tags that _getSQL() expands.  Group 1 is always
# the whole tag so it can be replaced verbatim; re.S (DOTALL) lets a tag's
# body span several lines.
params_regex = re.compile('(<params>(.*?)</params>)', re.S | re.M)
sqlvars_regex = re.compile('(<dtml-sqlvar (.*?) type="(int|string|float)">)', re.S | re.M)
dtmlvars_regex = re.compile('(<dtml-var (.*?)>)', re.S | re.M)
dtmlifs_regex = re.compile('(<dtml-if (.*?)>(.*?)</dtml-if>)', re.S | re.M)
dtmlcomments_regex = re.compile('(<dtml-comment>(.*?)</dtml-comment>)', re.S | re.M)
	
def sql_quote(v):
    """Return the string *v* as a single-quoted SQL string literal.

    Embedded single quotes are doubled ('' ), the SQL-standard escape.
    NOTE(review): backslashes are not escaped, so this assumes the server
    treats them literally (e.g. standard_conforming_strings) - confirm.
    """
    # str.replace is equivalent to the old string.split/string.join pair
    # and does not depend on the string-module functions removed in Py3.
    return "'%s'" % v.replace("'", "''")

def _splitParamSpec(spec):
    """Split one <params> entry into a (key, default) pair.

    ``name`` gives (name, '') and ``name=value`` gives (name, value).
    Only the first '=' separates key from default, so a default may
    itself contain '='.  A surrounding pair of double quotes and then a
    surrounding pair of single quotes are removed from the default.
    """
    if '=' not in spec:
        return spec, ""
    key, default = spec.split('=', 1)
    key, default = key.strip(), default.strip()
    if default.startswith('"') and default.endswith('"'):
        default = default[1:-1]
    if default.startswith("'") and default.endswith("'"):
        default = default[1:-1]
    return key, default


def _parseParams(params):
    """Interactively collect a value for every entry of a <params> block.

    params -- the text between <params> and </params>.  Entries are
              separated by whitespace and are either ``name`` or
              ``name=default`` (see _splitParamSpec).

    The user is prompted once per entry.  When the entry has a default,
    an empty answer accepts the default and the literal answer ``None``
    is stored as the Python value None.  Without a default an answer is
    mandatory.

    Returns a dict mapping parameter names to the entered values.
    Raises ValueError when an entry ends up with no value.
    """
    data = {}
    for spec in params.split():
        key, default = _splitParamSpec(spec)
        if default:
            value = raw_input("%s=%r " % (key, default)) or default
        else:
            value = raw_input("%s= " % key)

        if not value:
            # Must be a real exception class: string exceptions are
            # rejected with a TypeError by Python 2.6+ and removed in 3.
            raise ValueError("Value for key %r must be set" % key)

        if default and value == 'None':
            value = None

        data[key] = value
    return data

class MissingParamError(Exception):
    """Raised when the SQL template references an undeclared parameter.

    When a parameter name is given, the constructor also writes a
    one-line diagnostic to stderr.  The name is available both as
    ``param`` and as ``args[0]``.
    """

    def __init__(self, args=None):
        if args:
            sys.stderr.write("MissingParamError: Missing parameter %s\n" % args)
        # Let Exception build a proper args tuple.  Assigning the bare value
        # to self.args would turn a name string into a tuple of characters
        # and make a None argument raise TypeError.
        if args is None:
            Exception.__init__(self)
        else:
            Exception.__init__(self, args)
        self.param = args
	
def _getParam(params, name):
    """Return params[name], raising MissingParamError if it is absent."""
    try:
        return params[name]
    except KeyError:
        raise MissingParamError(name)


def _getSQL(file):
    """Render the DTML-style SQL template in *file* into plain SQL.

    Processing order (each step sees the output of the previous one):
      1. remove every <dtml-comment>...</dtml-comment> section;
      2. remove the first <params>...</params> block and prompt for its
         values via _parseParams();
      3. keep or drop each <dtml-if KEY>...</dtml-if> body depending on
         the truthiness of the value for KEY;
      4. replace <dtml-sqlvar NAME type="int|float|string"> with the
         converted value (strings are quoted with sql_quote());
      5. replace <dtml-var NAME> with str() of the raw value;
      6. strip whitespace and append ';' if the text does not end with one.

    Raises MissingParamError when a tag names a parameter that was not
    declared in <params>.  Errors from int()/float() on non-numeric input
    propagate unchanged.
    """
    f = open(file)
    try:
        sql = f.read()
    finally:
        f.close()

    for outer, _ in dtmlcomments_regex.findall(sql):
        sql = sql.replace(outer, '').strip()

    # Only the first <params> block is honoured.
    found = params_regex.findall(sql)
    if found:
        outer, inner = found[0]
        sql = sql.replace(outer, '').strip()
        params = _parseParams(inner)
    else:
        params = {}

    for outer, key, inner in dtmlifs_regex.findall(sql):
        if _getParam(params, key):
            sql = sql.replace(outer, inner)
        else:
            sql = sql.replace(outer, '')

    for code, name, typ in sqlvars_regex.findall(sql):
        value = _getParam(params, name)
        if value is None:
            # None is what _parseParams stores when 'None' is typed at a
            # prompt that has a default; render it as SQL null instead of
            # crashing in int()/float()/sql_quote().
            value = 'null'
        elif typ == 'int':
            value = int(value)
        elif typ == 'float':
            value = float(value)
        elif typ == 'string':
            value = sql_quote(value)
        sql = sql.replace(code, str(value))

    for code, name in dtmlvars_regex.findall(sql):
        sql = sql.replace(code, str(_getParam(params, name)))

    sql = sql.strip()
    if not sql.endswith(';'):
        sql = sql + ';'
    return sql
    
def _getRandomSQLfile():
    """Return a throw-away name of the form tmpsqlfile-<digits>.tmpsql.

    The digits are those of the current time.time() value (decimal point
    removed), randomly shuffled.  The name is relative, so the file is
    created in the current working directory.
    """
    digits = [ch for ch in str(time.time()) if ch != '.']
    random.shuffle(digits)
    return "tmpsqlfile-%s.tmpsql" % ''.join(digits)

def run(file, db_connection_string, commit=True, verbose=False):
    """Render the template *file* and execute it with psql.

    file                 -- path of the SQL template (see _getSQL)
    db_connection_string -- extra psql arguments, e.g. "-U peterbe quietdays";
                            inserted unquoted into a shell command line
    commit               -- if false, wrap the SQL in Begin;/Rollback;
    verbose              -- echo the rendered SQL before running it

    The SQL is written to a temporary file in the current directory, and
    psql's combined stdout/stderr is printed.  The temporary file is
    removed unless the output contains 'ERROR:', in which case it is left
    behind for inspection.
    """
    import subprocess

    sql = _getSQL(file)
    if verbose:
        print("====  SQL " + "=" * 50)
        print(sql)
        print("==== /SQL " + "=" * 50)

    if not commit:
        sql = "Begin;\n%s\nRollback;" % sql
        if verbose:
            print("Do NOT commit the SQL execution")

    sqlfile = _getRandomSQLfile()
    # Close (and so flush) the file before psql reads it; a bare
    # open(...).write(...) relies on CPython refcounting to do that.
    f = open(sqlfile, 'w')
    try:
        f.write(sql)
    finally:
        f.close()

    cmd = 'psql %s -f %s' % (db_connection_string, sqlfile)
    # Replacement for the deprecated os.popen4(): run through the shell,
    # merge stderr into stdout, and give the child a stdin pipe that
    # communicate() closes straight away.
    proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                            universal_newlines=True)
    out = proc.communicate()[0]

    keep_tmp_file = out.find('ERROR:') > -1
    if keep_tmp_file:
        # Drop the first two ':'-separated fields - presumably psql's
        # "psql:<tmpfile>:" prefix - so the message starts at the line number.
        out = ':'.join(out.split(':')[2:])
    print(out)

    if not keep_tmp_file:
        os.remove(sqlfile)

def grr():
    """Show the usage text (the module docstring) and exit with status 1."""
    print(__doc__)
    raise SystemExit(1)
    
def _parseConfFile(file):
    """Return the saved psql connection arguments stored in *file*.

    The script's entry point writes the space-joined connection arguments
    to this file; surrounding whitespace is stripped on reading.
    """
    f = open(file)
    try:
        return f.read().strip()
    finally:
        f.close()

if __name__ == '__main__':
    # Command-line entry point; see the module docstring for usage.
    args = sys.argv[1:]
    if not args:
        grr()  # exits

    sql_file = None
    commit = True
    verbose = False
    connection_args = []
    for arg in args:
        if os.path.isfile(arg):
            # Any argument naming an existing file is taken as the SQL file.
            sql_file = arg
        elif arg in ('--test', '-t'):
            commit = False
        elif arg == '--version':
            print(__version__)
            # Printing the version is a successful run: exit status 0.
            sys.exit(0)
        elif arg in ('-v', '--verbose'):
            verbose = True
        else:
            # Everything else is passed through to psql (e.g. "-U name db").
            connection_args.append(arg)

    if not sql_file:
        grr()  # exits

    if connection_args:
        db_connection_string = ' '.join(connection_args)
        # Remember the connection arguments for later invocations.
        f = open(conffile_home, 'w')
        try:
            f.write(db_connection_string)
        finally:
            f.close()
    elif os.path.isfile(conffile_home):
        db_connection_string = _parseConfFile(conffile_home)
    else:
        grr()  # exits

    run(sql_file, db_connection_string, commit, verbose=verbose)