package com.sap.security.ssl;

/**
 * Title:        User management
 * Description:  Factory for IAIK HTTPS connections configured from a Properties object
 * Copyright:    Copyright (c) 2001
 * Company:      SAPPortals
 * @author Kai Ullrich
 * @version 1.0
 */

import java.net.URL;

import java.util.Enumeration;
import java.util.Properties;
import java.util.Vector;
import java.util.Hashtable;

import java.io.IOException;
import java.io.File;
import java.io.FileInputStream;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.ByteArrayInputStream;


import java.net.URLStreamHandler;
import java.net.MalformedURLException;

import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.GeneralSecurityException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.security.Principal;

import com.sap.tc.logging.Location;
import com.sap.tc.logging.Severity;

import iaik.security.ssl.SSLClientContext;
import iaik.security.ssl.SSLContext;
import iaik.protocol.https.HttpsURLConnection;
import iaik.protocol.https.HttpsURLStreamHandlerFactory;
import iaik.security.ssl.ChainVerifier;
import iaik.security.ssl.SSLTransport;
import iaik.security.provider.IAIK;

public class HttpsConnectionFactory
{
    /** The default SSLContext. Overwrite it if you want to use a special
     *  default trust store, for instance.
     */
    protected static SSLContext defaultSSLContext;
    private   Properties config;
    private   KeyStore   keystore;
    private   Vector     trustedCerts = null;
    private   String     keyAlias = null;
    private   String     useClientCert = null;
    private   String     keystorepass = null;
    private   String     debugon    = null;
    static private Location myLoc = Location.getLocation(HttpsConnectionFactory.class); 

    private   static     HttpsURLStreamHandlerFactory urlStreamHandlerFactory;

    /** Gets the recommended URLStreamHandler instance. You can also
     *  specify the URLStreamHandlerFactory to be used by the
     *   -Djava.protocol.handler.pkgs system property or
     *  {@link URL#setURLStreamHandlerFactory()}. In both cases you
     *  set the handler VM-wide and you may interfere with other
     *  applications that run in the same VM. To avoid these problems,
     *  use the constructor {@link URL(String protocol, String host, int port, String file, URLStreamHandler handler)}
     *  just as in the following example:
     *  <PRE>
     *    URL url = new URL (protocol, host, port, file, HttpsConnectionFactory.getURLStreamHandler ());
     *    HttpsConnectionFactory factory = new HttpsConnectionFactory (some_properties_object);
     *    HttpsURLConnection con = factory.getConnection (url);
     *    con.connect ();
     *    ...
     *  </PRE>
     *  @return an https URLStreamHandler instance
     */
    public static URLStreamHandler getURLStreamHandler ()
    {
        return urlStreamHandlerFactory.createURLStreamHandler ("https");
    }

    /** Sets a special SSL context as the default SSLContext to be used.
     *  You can use this to implement a special behaviour in things like
     *  trusted server certificates, for instance.
     *  @param context the default SSL context to be used.
     */
    public static void setDefaultSSLContext (SSLContext context)
    {
        defaultSSLContext = context ;
    }

    /** Opens an HttpsURLConnection for a URL object.
     *  Calls {@link URL#openConnection()} and initializes the resulting
     *  {@link HttpsURLConnection} object with the SSL context.
     *  @param url URL object representing the connection data
     *  @return HttpsURLConnection object with an SSL client context
     *
     */
    public HttpsURLConnection getConnection (URL url)
        throws IOException, KeyStoreException
    {
        if (url.getProtocol().equals ("https")==false)
            throw new IllegalStateException ("Only allowd for https URLs");

        HttpsURLConnection basicConnection = (HttpsURLConnection) url.openConnection ();
        basicConnection.setSSLContext (getSSLClientContext());

        return basicConnection ;
    }

    /** Constructor. Specify a properties object with the following parameters:
     *  <table>
     *   <tr><td><i>Parameter name</i></td><td><i>Meaning</i></td><td><i>Default</i></td></tr>
     *   <tr><td>ssl.keystoretype</td><td>Type of the keystore (IAIK, SUN)</td><td>SUN</td></tr>
     *   <tr><td>ssl.keystore</td><td>file name of the keystore</td><td><i>none</i> (must be specified)</td></tr>
     *   <tr><td>ssl.truststore</td><td>file name of the truststore</td><td>${JAVA_HOME}/lib/security/cacerts</td></tr>
     *   <tr><td>ssl.clientcert</td><td>whether a client certificate is to be used or not (0, 1)</td><td>0</td></tr>
     *   <tr><td>ssl.keystorepass</td><td>password of the keystore</td><td></td></tr>
     *   <tr><td>ssl.debugon</td><td>Debug mode (prints debug stuff on stdout)</td><td>0</td></tr>
     *  </table>
     *
     */
    public HttpsConnectionFactory (Properties config)
        throws KeyStoreException, IOException
    {
        this.config = config;

        String keytype = config.getProperty ("ssl.keystoretype");
        String keystore_f = config.getProperty ("ssl.keystore");
        String trustStore = config.getProperty ("ssl.truststore");
               useClientCert = config.getProperty ("ssl.clientcert");
               //new decryption of the keystore password
               keystorepass = config.getProperty ("ssl.keystorepass").substring(1);
               keystorepass = new String(new sun.misc.BASE64Decoder().decodeBuffer(keystorepass));
               debugon    = config.getProperty ("ssl.debugon");


        char [] trustStorePass = null;
        try {
            if (keystore_f!=null) {
                if (keytype==null || keytype.equals ("SUN")) {
                    keystore = KeyStore.getInstance ("JKS");
                }
                else if (keytype.equals ("IAIK")) {
                    IAIK.addAsProvider();
                    keystore = KeyStore.getInstance ("IAIKKeyStore", "IAIK");
                }
                else {
                    keystore = null;
                }
                if (keystore!=null)
                    keystore.load (new FileInputStream (keystore_f),
                                                        makePwFromKeystorePass (keystorepass));

            }

            if (trustStore==null) {
                String s1 = System.getProperty ("java.home") + File.separator +
                            "jre" + File.separator + "lib" + File.separator +
                            "security" + File.separator + "cacerts";
                String s2 = System.getProperty ("java.home") + File.separator +
                            "lib" + File.separator +
                            "security" + File.separator + "cacerts";

                if (true == (new File (s1)).exists ()) {
                    trustStore = s1;
                }
                else {
                    trustStore = s2;
                }
            }
            if (trustStore!=null) {
                KeyStore trustStore_ = KeyStore.getInstance ("JKS");
                trustStore_.load (new FileInputStream (trustStore), new char [] { 'c', 'h',
                                                                                  'a', 'n',
                                                                                  'g', 'e',
                                                                                  'i', 't' });

                Enumeration allCerts = trustStore_.aliases() ;
                trustedCerts = new Vector (10);
                while (true) {
                    if (!allCerts.hasMoreElements())
                        break;

                    String alias = (String)allCerts.nextElement ();

                    if (true == trustStore_.isCertificateEntry (alias)) {
                        trustedCerts.add (new iaik.x509.X509Certificate (trustStore_.getCertificate(alias).getEncoded()));
                    }
                }
            }
            keyAlias = getKeyAlias() ;
        }
        catch (GeneralSecurityException gse) {
            gse.printStackTrace() ;
            throw new KeyStoreException (gse.toString());
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            throw new KeyStoreException (ioe.toString());
        }
    }

    /** Gets the client context according to the parameters in the properties object.
     *  @return SSLClientContext object
     */
    protected SSLClientContext getSSLClientContext ()
        throws KeyStoreException
    {
        SSLClientContext clientContext = new SSLClientContext ();
//        clientContext.

        if (keyAlias!=null && useClientCert.equals("1")) {
            Certificate certs [] = null;
            java.security.PrivateKey key = null;
            try {
                certs = (Certificate [])
                         keystore.getCertificateChain(keyAlias);
                key = (java.security.PrivateKey) keystore.getKey (keyAlias,
                                                  makePwFromKeystorePass (keystorepass));
            }
            catch (GeneralSecurityException gkse) {
                gkse.printStackTrace ();
                throw new KeyStoreException ("" + gkse);
            }
            try {
                int i = certs.length;
                //CertificateFactory cf = CertificateFactory.getInstance ("X.509");
                iaik.x509.X509Certificate x509Certs [] = new iaik.x509.X509Certificate [i];

                for (i=0; i<certs.length; i++) {
                    x509Certs[i] = (iaik.x509.X509Certificate)
                                   new iaik.x509.X509Certificate (
                                   new ByteArrayInputStream (certs[i].getEncoded())
                                   );
                }

                if (x509Certs[0].getSubjectDN().toString().equals (x509Certs[0].getIssuerDN().toString())) {
                    revertCertOrder (x509Certs);
                }
                clientContext.addClientCredentials (x509Certs, key);
            }
            catch (GeneralSecurityException gse) {
                gse.printStackTrace();
                throw new KeyStoreException (gse.toString() );
            }
            catch (IOException ioe) {
                ioe.printStackTrace();
                throw new KeyStoreException (ioe.toString());
            }
        }

        int i;
        //ChainVerifier c = new NullChainVerifier ();
        ChainVerifier c = clientContext.getChainVerifier() ;
        for (i=0; i<trustedCerts.size(); i++) {

            c.addTrustedCertificate ((java.security.cert.X509Certificate)
                                                  trustedCerts.elementAt(i));
            /*
            System.out.println("Adding trusted: " +
                               ((java.security.cert.X509Certificate)
                                                  trustedCerts.elementAt(i))
                                                  .getSubjectDN().toString()); */
        }
        clientContext.setChainVerifier (c);
        if ("1".equals (debugon))
            clientContext.setDebugStream (System.out);

        Principal [] ps = clientContext.getChainVerifier().getTrustedPrincipalsArray();

        /*
        for (int i__=0; i__<ps.length; i__++) {
            System.out.println("Principal " + i__ + " = " + ps[i__].toString());
        }*/

        return clientContext ;
    }

    static {
        urlStreamHandlerFactory = new HttpsURLStreamHandlerFactory();
    }

    private String getKeyAlias ()
        throws KeyStoreException
    {
        if (keystore==null)
            return null;

        Enumeration aliases = keystore.aliases() ;

        while (true) {
            if (!aliases.hasMoreElements())
                break;

            String alias = (String)aliases.nextElement() ;

            if (keystore.isKeyEntry (alias))
                return alias;
        }
        return null;
    }

    private static char [] makePwFromKeystorePass (String configEntry)
    {
        //to be changed
        return configEntry.toCharArray();
    }

    /** main method for test purposes. Call main for a short description
     *  of the usage of the program.
     */
    public static void main (String args [])
    {
        if (args.length<2) {
            System.out.println("Usage:");
            System.out.println(" java com.sap.security.ssl.HttpsConnectionFactory <location of config file> "+
                               "<URL> [-H<http-header>=<http-header-value> [-H<http-header>=... ]]");

        }
        try {
            URL url = parseURLAndCreateObject (args[1], getURLStreamHandler ());
            Properties p = new Properties ();
            p.load (new FileInputStream (args[0]));
            BufferedReader r = null;
            Hashtable hm = getHeaders (args);
            HttpsConnectionFactory factory = new HttpsConnectionFactory (p);

            HttpsURLConnection con = factory.getConnection (url);
            con.setRequestMethod ("GET");

            Enumeration e = hm.elements();
            while (e.hasMoreElements()) {
                String key = (String)e.nextElement();
                con.setRequestProperty (key, (String)hm.get(key));
            }
            con.connect ();
            int rc = con.getResponseCode ();

            // Hier passiert was ganz wichtiges.
            r = new BufferedReader (new InputStreamReader (con.getInputStream()));
            System.out.println("\n>>> Response code ist: " + rc +" \n");
            System.out.println("" + con.getHeaderField ("set-cookie"));
            String line = r.readLine ();
            while (line!=null) {
                System.out.println(line);
                line = r.readLine ();
            }
        }
        catch (Exception e) {
            e.printStackTrace ();
            System.out.println("\n>>>Exiting because of Exception.");
            System.exit (1);
        }
    }

    private static Hashtable getHeaders (String args [])
    {
        int     i         = 0;
        Hashtable hm2return = new Hashtable (5);;

        for (i=0; i<args.length; i++) {
            if (args[i].startsWith("-H")) {
                String toParse = args[i].substring(2);
                try {
                    hm2return = addHeaderToMap (toParse, hm2return);
                }
                catch (Exception e) {
                	myLoc.traceThrowableT(Severity.ERROR, "getHeaders", e);
                }
            }
        }
        return hm2return ;
    }

    private static Hashtable addHeaderToMap (String keyValue, Hashtable hm)
    {
        int idx = keyValue.indexOf ((int)'=');
        String key = keyValue.substring (0, idx);
        String value = keyValue.substring (idx+1);
        System.out.println("Adding " + key + " = " + value);
        if (hm==null)
            hm=new Hashtable (5);
        hm.put (key, value);
        return hm;
    }

    private static void revertCertOrder (Object [] os)
    {
        int i=0;
        int j=os.length;

        Object []os2 = new Object [j];

        for (i=0; i<j; i++) {
            os2 [j-i-1] = os [i];
        }

        for (i=0; i<j; i++) {
            os [i] = os2[i];
        }
    }

    private static URL parseURLAndCreateObject (String urlString, URLStreamHandler fac)
        throws MalformedURLException
    {
        String  protocol = null,
                host     = null,
                file     = null;
        int     port     = 443;
        int     iIdx,
                iIdx2    = -1;

        if ((iIdx = urlString.indexOf ("://"))==-1) {
            throw new MalformedURLException ("No valid URL.");
        }

        protocol = urlString.substring (0, iIdx);
        //Find index of : and /
        iIdx = urlString.indexOf (":", iIdx+3) ;
        iIdx2= urlString.indexOf ("/", protocol.length()+3) ;

        if (iIdx==-1 && iIdx2==-1) {
            // URL form https://host.domain.country
            // Append SLASH and everything's fine.
            host = urlString.substring (protocol.length() + 3);
            file = "/";
        }
        else if (iIdx==-1) {
            // URL form https://host.domain.country/index.html
            host = urlString.substring (protocol.length() + 3, iIdx2);
            file = urlString.substring (iIdx2);
        }
        else if (iIdx2==-1) {
            // URL form https://host.domain.country:1800
            host = urlString.substring (protocol.length() + 3, iIdx);
            try {
                port = Integer.parseInt (urlString.substring (iIdx+1));
            }
            catch (NumberFormatException n) {
                throw new MalformedURLException ("No valid URL.");
            }

            file = "/";
        }
        else {
            // both are not == -1;
            if (iIdx2 < iIdx) {
                // URL form https://host.domain.country/index.html
                host = urlString.substring (protocol.length() + 3, iIdx);
                file = urlString.substring (iIdx);
            }
            else {
                host = urlString.substring (protocol.length() + 3, iIdx);
                try {
                    port = Integer.parseInt (urlString.substring (iIdx+1, iIdx2));
                }
                catch (NumberFormatException n) {
                    throw new MalformedURLException ("No valid URL.");
                }
                file = urlString.substring (iIdx2); // this time not iIdx2 + 1 since we want the slash
            }
        }

        System.out.println("URL is: " + protocol + "://" + host + ":" + port + file);

        return new URL (protocol, host, port, file, getURLStreamHandler ());
    }
}

/**
 * Debugging chain verifier: prints every certificate of the presented chain
 * to System.out and then accepts the chain unconditionally.
 *
 * <p><b>Security warning:</b> {@link #verifyChain} always returns true, so
 * installing this verifier disables server authentication. HttpsConnectionFactory
 * does not install it; use it only for local debugging.
 */
class NullChainVerifier extends ChainVerifier
{
    /**
     * Dumps subject, issuer, hexadecimal serial number and validity period of
     * each certificate in {@code certs}, then accepts the chain.
     *
     * @param certs     the certificate chain to inspect
     * @param transport the SSL transport (not used)
     * @return always {@code true}
     */
    public boolean verifyChain (X509Certificate [] certs, SSLTransport transport)
    {
        System.out.println("Chain to verify:");

        int idx = 0;
        while (idx < certs.length) {
            X509Certificate cert = certs[idx];
            System.out.println("cert [" + idx + "]");
            System.out.println(" Subject: " + cert.getSubjectDN().toString());
            System.out.println(" Issuer: " + cert.getIssuerDN().toString());
            System.out.println(" Serial: " + cert.getSerialNumber().toString(16));
            System.out.println(" Not before: " + cert.getNotBefore().toString());
            System.out.println(" Not after: " + cert.getNotAfter().toString());
            // A fingerprint would have been nice to print as well, but is not available here.
            idx++;
        }

        return true;
    }
}