
#
#    ========== licence begin LGPL
#    Copyright (C) 2000 SAP AG
#
#    This library is free software; you can redistribute it and/or
#    modify it under the terms of the GNU Lesser General Public
#    License as published by the Free Software Foundation; either
#    version 2.1 of the License, or (at your option) any later version.
#
#    This library is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#    Lesser General Public License for more details.
#
#    You should have received a copy of the GNU Lesser General Public
#    License along with this library; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#    ========== licence end
#

import sys
import string

import loader

class Error (Exception):
    """Error raised by LoaderSession for invalid usage (e.g. missing arguments).

    msg -- human-readable description, also available as self.msg.
    """

    def __init__ (self, msg):
        # Initialise the Exception base so self.args is populated and the
        # instance behaves like a regular exception (tracebacks, pickling).
        Exception.__init__ (self, msg)
        self.msg = msg

    def __str__ (self):
        # '%s' instead of '+' so a non-string msg does not make str() fail;
        # for string messages the result is identical to plain concatenation.
        return '<loaderEx.Error %s>' % (self.msg,)

class LoaderSession:
    verbose = None
    quiet = None

    def __init__ (self, host = None, dbname = None, user = None,
            pwd = None, autocommit = 1, logfile = None, sqlmode = None):
        args = [host, dbname, user, pwd]
        filterResult = filter (None,args )
        argsGiven = len (filterResult)
        if argsGiven == 0:
            self.session = self.openBySysArgv ()
        else:
            self.session = self.openByArgs (host, dbname, user, pwd)
            self.args = []
        if logfile:
            self.openLog (logfile)
        if not autocommit:
            self.cmd ('autocommit off')
        if sqlmode:
            self.cmd ('sqlmode ' + sqlmode)

    def openBySysArgv (self):
        args = sys.argv [1:]
        optdesc = [
            ('u', 'user', ':', None, 'user name [<user>,<pwd>]'),
    	    ('d', 'db', ':', None, 'db name'),
            ('n', 'node', ':', '', 'server name'),
            ]
        import optlib
        options, self.args = optlib.parseArgs (optdesc, None, args)
        pos = string.find (options.user, ',')
        if pos == -1:
            user = options.user
            pwd = ''
        else:
            user = options.user [:pos]
            pwd = options.user [pos + 1:]
        return self.openByArgs (options.node, options.db, user, pwd)

    def checkArgs (self, dbname, user, pwd):
        missingArgs = []
        for value, name in [(dbname, 'database name'),
                            (user, 'user name'),
                            (pwd, 'user password')]:
            if not value:
                missingArgs.append (name)
        if missingArgs:
            raise Error ('argument missing: ' + string.join (missingArgs, ', '))

    def openByArgs (self, host, dbname, user, pwd):
        self.checkArgs (dbname, user, pwd)
        session = loader.Loader (host, dbname)
        cmdString = 'use user %(user)s %(pwd)s serverdb %(dbname)s ' % locals ()
        if host:
            cmdString = cmdString + 'on ' + host
        session.cmd (cmdString)
        return session

    def openLog (self, stream):
        if type (stream) == type (''):
            stream = open (stream, 'w')
        self.write = stream.write

    def write (self, text):
        pass

    def log (self, *items):
        self.write (string.join (items) + '\n')

    def cmd (self, cmdstr):
        if self.verbose:
            print cmdstr,
        try:
            result = self.session.cmd (cmdstr)
            self.log (cmdstr, '=> OK', result)
        except loader.LoaderError, err:
            self.log (cmdstr, '=> ERR', str (err))
            raise loader.LoaderError, err, sys.exc_traceback
        except:
            kind, val, traceback = sys.exc_info()
            self.log (cmdstr, '=> unexpected exception', str (kind), str (val))
            raise kind, val, traceback

    def sqlRC (self, cmdstr):
        return self.allowSpecificErrors (0, cmdstr)

    def sqlOK (self, cmdstr):
        rc = self.sqlRC (cmdstr)
        return rc == 0

    def allowSpecificErrors (self, allowedRCs, cmdstr):
        try:
            self.cmd(cmdstr)
            rc = 0
        except loader.LoaderError, err:
            if type (allowedRCs) == type (0):
                allowedRCs = [allowedRCs, 0]
            else:
                allowedRCs = allowedRCs + [0]
            rc = err.sqlCode
            if rc not in allowedRCs:
                raise err  # this is an Loader error
        return rc

    def include (self, *modulelist):
        for file in files:
            if not self.quiet:
                print file
            self.includeOne (file)

    def includeOne (self, fname):
        print "Don't know how to include", fname

    def multiple (self, commands):
        stringType = type ('')
        for element in commands:
            allowedErrors = None
            if type (element) == stringType:
                cmdString = element
            else:
                if len (element) == 0:
                    continue
                elif len (element) == 1:
                    cmdString = element [0]
                else:
                    cmdString = element [0]
                    allowedErrors = element [1]
                    try:
                        # convert single integer to list
                        value = int (allowedErrors)
                        allowedErrors = [value]
                    except ValueError:
                        pass
            if allowedErrors:
                self.allowSpecificErrors (allowedErrors, cmdString)
            else:
                self.cmd (cmdString)


class TeeStream:
    """File-like fan-out: every write() is forwarded to each wrapped stream.

    streams -- sequence of objects that provide a write(text) method;
               the sequence itself is kept, not copied.
    """

    def __init__ (self, streams):
        self.streams = streams

    def write (self, text):
        # Targets are written in sequence order; an exception raised by one
        # target propagates and the remaining targets are not written.
        for target in self.streams:
            target.write (text)

