/*
 * Copyright (c) 2003 by SAP AG. All Rights Reserved.
 *
 * SAP, mySAP, mySAP.com and other SAP products and
 * services mentioned herein as well as their respective
 * logos are trademarks or registered trademarks of
 * SAP AG in Germany and in several other countries all
 * over the world. MarketSet and Enterprise Buyer are
 * jointly owned trademarks of SAP AG and Commerce One.
 * All other product and service names mentioned are
 * trademarks of their respective companies.
 *
 * @version $Id$
 */

package com.sapportals.wcm.util.http.slim;

import com.sapportals.wcm.WcmException;
import com.sapportals.wcm.util.http.*;

import java.net.InetAddress;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import javax.crypto.Cipher;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.DESKeySpec;

/**
 * WDNTLMAuthentication is an {@link ICredentials} implementation which
 * authenticates via the Windows NTLM challenge/response protocol.
 * <p>
 * Handshake state is kept separately for every {@link IRequester}.
 * {@link #setup} generates the type 1 message (state FRESH to
 * RETRIEVING_CHALLENGE) and, once the server's type 2 challenge header is
 * passed in, the type 3 message (state CHALLENGE_RECEIVED). {@link #got}
 * moves on to AUTHENTICATION_SENT when a response is processed after the
 * type 3 message was prepared. {@link #apply} adds the message matching the
 * current state to outgoing requests.
 * <p>
 *
 * Copyright (c) SAP AG 2001-2003
 *
 * @author stefan.eissing@greenbytes.de
 * @version $Id: WDNTLMAuthentication.java,v 1.3 2003/02/17 14:24:04 jre Exp $
 */
final class WDNTLMAuthentication implements ICredentials {

  private final static com.sap.tc.logging.Location log = com.sap.tc.logging.Location.getLocation(com.sapportals.wcm.util.http.slim.WDNTLMAuthentication.class);

  /** Nothing generated yet for the requester. */
  private final static int STATE_FRESH = 0;
  /** Type 1 message generated; waiting for the server challenge. */
  private final static int STATE_RETRIEVING_CHALLENGE = 1;
  /** Type 3 message generated from the server challenge. */
  private final static int STATE_CHALLENGE_RECEIVED = 2;
  /** A response following the type 3 message has been processed by got(). */
  private final static int STATE_AUTHENTICATION_SENT = 3;

  /**
   * NTLM handshake state of a single requester.
   */
  private final static class ConnectionState {
    /**
     * Did we have valid setup information?
     */
    public boolean valid;
    /** Current handshake state, one of the STATE_* constants. */
    public int state = STATE_FRESH;
    /** Local host name without its DNS domain part. */
    public String hostname;
    /** Domain derived from the local host's DNS suffix; may be empty. */
    public String ntDomain;
    /** Base64 encoded type 1 message, or null if not generated. */
    public String message1;
    /** Base64 encoded type 3 message, or null if not generated. */
    public String message3;

    public ConnectionState() { }
  }

  /**
   * the context the credentials belongs to
   */
  private final UserInfo m_info;
  /** Handshake state (ConnectionState) keyed by IRequester; synchronized map. */
  private final Map states;

  /**
   * Creates NTLM credentials for the given user.
   *
   * @param ui user whose name and password answer the NTLM challenge
   * @exception WcmException declared for callers; not thrown by this constructor
   */
  public WDNTLMAuthentication(UserInfo ui)
    throws WcmException {
    this.m_info = ui;
    this.states = Collections.synchronizedMap(new HashMap(17));
  }

  /**
   * @return the authentication scheme name, "NTLM"
   */
  public String getName() {
    return "NTLM";
  }

  /**
   * Adds the NTLM message for the requester's current handshake state to
   * the request, using the header {@code headerName}.
   *
   * @param requester requester whose handshake state is used
   * @param uri request URI (not used by this implementation)
   * @param request request that receives the authorization header
   * @param headerName name of the header to set
   * @return whether the credentials are still valid for this requester;
   *         false whenever no message could be applied
   */
  public synchronized boolean apply(IRequester requester, String uri, IRequest request, String headerName)
    throws WcmException {
    ConnectionState cs = getState(requester);
    // No valid setup information supplied, empty credentials
    if (!cs.valid) {
      return false;
    }

    switch (cs.state) {
      case STATE_FRESH:
        // setup() has not generated any message yet
        cs.valid = false;
        break;
      case STATE_RETRIEVING_CHALLENGE:
        if (cs.message1 != null) {
          request.setHeader(headerName, "NTLM " + cs.message1);
        }
        else {
          // inconsistent state: nothing to send, so do not report success
          cs.valid = false;
          endUse(requester);
        }
        break;
      case STATE_CHALLENGE_RECEIVED:
        if (cs.message3 != null) {
          request.setHeader(headerName, "NTLM " + cs.message3);
        }
        else {
          // inconsistent state: nothing to send, so do not report success
          cs.valid = false;
          endUse(requester);
        }
        break;
      case STATE_AUTHENTICATION_SENT:
        // no need to send it again
        break;
      default:
        // something wrong in our state machine?
        cs.valid = false;
        endUse(requester);
        break;
    }

    return cs.valid;
  }

  /**
   * Set the credentials for this context depending on the header information
   * (received from a 401). Same as {@code setup(requester, header, false)}.
   *
   * @param requester requester whose handshake state is advanced
   * @param header value of the authentication header sent by the server
   * @return whether the credentials are valid for the next request
   * @exception WcmException Exception raised in failure situation
   */
  public synchronized boolean setup(IRequester requester, String header)
    throws WcmException {
    return setup(requester, header, false);
  }

  /**
   * Set the credentials for this context depending on the header information
   * (received from a 401).
   * <p>
   * In state FRESH the type 1 message is generated; in state
   * RETRIEVING_CHALLENGE the challenge in {@code header} is used to generate
   * the type 3 message.
   *
   * @param requester requester whose handshake state is advanced
   * @param header value of the authentication header sent by the server
   * @param retry not used by this implementation
   * @return whether the credentials are valid for the next request
   * @exception WcmException Exception raised in failure situation
   */
  public synchronized boolean setup(IRequester requester, String header, boolean retry)
    throws WcmException {
    ConnectionState cs = getState(requester);
    if (log.beDebug()) {
      log.debugT("setup(145)", "SETUP in state(" + cs.state + "): " + header);
    }

    switch (cs.state) {
      case STATE_FRESH:
        cs.valid = generateMessage1(cs);
        if (cs.valid) {
          cs.state = STATE_RETRIEVING_CHALLENGE;
        }
        break;
      case STATE_RETRIEVING_CHALLENGE:
      {
        cs.message3 = generateMessage3(header, cs.hostname, cs.ntDomain);
        cs.valid = (cs.message3 != null);
        if (cs.valid) {
          cs.state = STATE_CHALLENGE_RECEIVED;
        }
        break;
      }

      case STATE_CHALLENGE_RECEIVED:
        // Intentionally left unchanged (an earlier version reset the
        // state to STATE_FRESH and invalidated the credentials here).
        break;
      case STATE_AUTHENTICATION_SENT:
        cs.valid = true;
        break;
      default:
        // something wrong with our state machine?
        cs.valid = false;
        cs.state = STATE_FRESH;
        break;
    }

    return cs.valid;
  }

  /**
   * Process authenticate information in the response message, like
   * Authenticate-Info.
   * <p>
   * If the type 3 message was prepared (state CHALLENGE_RECEIVED), the
   * handshake moves to AUTHENTICATION_SENT. An "Authentication-Info" header
   * is not evaluated; its presence is only logged at debug level.
   *
   * @param requester requester whose handshake state is updated
   * @param response the received response
   * @exception WcmException Exception raised in failure situation
   */
  public synchronized void got(IRequester requester, IResponse response)
    throws WcmException {
    ConnectionState cs = getState(requester);
    if (cs.state == STATE_CHALLENGE_RECEIVED) {
      if (cs.message3 != null) {
        cs.state = STATE_AUTHENTICATION_SENT;
      }
      else {
        //TODO: maybe state failed would be better?
        cs.state = STATE_FRESH;
      }
    }
    String value = response.getHeader("Authentication-Info");
    if (value == null) {
      return;
    }

    // Not evaluated for NTLM; make its occurrence visible when debugging.
    if (log.beDebug()) {
      log.debugT("got()", "ignoring Authentication-Info header in NTLM response");
    }
  }

  /**
   * Prepares the requester's handshake state for a new use, see resetState.
   *
   * @param requester the requester starting to use these credentials
   */
  public void startUse(IRequester requester) {
    this.resetState(requester);
  }

  /**
   * Drops all handshake state kept for the requester.
   *
   * @param requester the requester no longer using these credentials
   */
  public void endUse(IRequester requester) {
    this.removeState(requester);
  }

  /**
   * @param requester ignored
   * @return always true
   */
  public boolean canTriggerAuthentication(IRequester requester) {
    return true;
  }

  // ------------------ private / protected ---------------------------

  /**
   * Returns the requester's state, creating and registering a fresh one if
   * none exists. Callers hold the lock on this object.
   */
  private ConnectionState getState(IRequester requester) {
    ConnectionState cs = (ConnectionState)this.states.get(requester);
    if (cs == null) {
      cs = new ConnectionState();
      this.states.put(requester, cs);
    }
    return cs;
  }

  /** Removes the requester's state, if any. */
  private void removeState(IRequester requester) {
    this.states.remove(requester);
  }

  /**
   * Rewinds an existing state: a valid type 1 message is kept so the
   * handshake restarts at RETRIEVING_CHALLENGE; otherwise everything is
   * cleared back to FRESH.
   */
  private void resetState(IRequester requester) {
    ConnectionState cs = (ConnectionState)this.states.get(requester);
    if (cs != null) {
      if (cs.valid && cs.message1 != null) {
        cs.state = STATE_RETRIEVING_CHALLENGE;
        cs.message3 = null;
      }
      else {
        cs.state = STATE_FRESH;
        cs.message1 = null;
        cs.message3 = null;
        cs.valid = false;
      }
    }
  }

  /**
   * Determines local host name and domain (the part after the first '.' of
   * the local host name) and generates the type 1 message into {@code cs}.
   *
   * @return true if a type 1 message was generated
   */
  private boolean generateMessage1(ConnectionState cs) {
    if (log.beDebug()) {
      log.debugT("generateMessage1(256)", "generating message type 1");
    }
    try {
      cs.hostname = InetAddress.getLocalHost().getHostName();
      //TODO: FIXME where do we get our nt domain name?
      cs.ntDomain = "";
      if (log.beDebug()) {
        log.debugT("generateMessage1(262)", "hostname: " + cs.hostname);
      }
      int index = cs.hostname.indexOf('.');
      if (index > 0) {
        if (index < cs.hostname.length() - 1) {
          cs.ntDomain = cs.hostname.substring(index + 1);
        }
        cs.hostname = cs.hostname.substring(0, index);
      }

      cs.message1 = NTLMMessageHandler.getMessage1(cs.hostname, cs.ntDomain);

      return (cs.message1 != null);
    }
    catch (Exception ex) {
      log.warningT("generateMessage1(277)", "determining local host/domain information" + " - " + com.sapportals.wcm.util.logging.LoggingFormatter.extractCallstack(ex));
      return false;
    }
  }

  /**
   * Extracts the challenge (last space separated token) from the server's
   * header and generates the type 3 message for it.
   *
   * @param header authentication header value, e.g. "NTLM &lt;base64&gt;"; may be null
   * @param hostname local host name
   * @param ntDomain default domain, overridden by a domain part in the user name
   * @return Base64 encoded type 3 message, or null on failure
   */
  private String generateMessage3(String header, String hostname, String ntDomain) {
    if (log.beDebug()) {
      log.debugT("generateMessage3(284)", "generating message type 3 from header: " + header);
    }
    if (header == null) {
      // no challenge to answer
      return null;
    }
    try {
      header = header.trim();
      int index = header.lastIndexOf(' ');
      if (index > 0 && index < header.length() - 1) {
        String param = header.substring(index + 1);
        byte[] nonce = NTLMMessageHandler.getNonce(param);
        if (nonce != null) {
          return NTLMMessageHandler.getMessage3(hostname, ntDomain, this.m_info, nonce);
        }
      }

      return null;
    }
    catch (NoSuchAlgorithmException ex) {
      log.warningT("generateMessage3(300)",  "could not obtain cipher for DES/ECB, NTLM authentication not available: " + " - " + com.sapportals.wcm.util.logging.LoggingFormatter.extractCallstack(ex));
      return null;
    }
    catch (Exception ex) {
      log.warningT("generateMessage3(300)", "retrieving nonce from NTLM header" + " - " + com.sapportals.wcm.util.logging.LoggingFormatter.extractCallstack(ex));
      return null;
    }
  }

  /**
   * Encoding, decoding and crypto helpers for the NTLM messages.
   */
  private final static class NTLMMessageHandler {

    /** Signature every NTLM message starts with. */
    private final static byte[] MSG_HEADER = new byte[]{
      'N', 'T', 'L', 'M',
      'S', 'S', 'P', 0,
      };
    /** Position of the message type byte. */
    private final static int MSG_POS_TYPE = 8;

    /**
     * Fixed 32 byte part of the type 1 message: signature, type 1, flags,
     * domain buffer (16-23), host buffer (24-31) whose offset is preset to
     * 0x20, i.e. directly after this template.
     */
    private final static byte[] MSG1_TEMPLATE = new byte[]{
      'N', 'T', 'L', 'M',
      'S', 'S', 'P', 0,
      1, 0, 0, 0,
      3, (byte)0xb2, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0x20, 0, 0, 0,
      };

    private final static int MSG1_POS_DOMAIN_OFFSET = 20;
    private final static int MSG1_POS_DOMAIN_LENGTH = 16;
    private final static int MSG1_POS_HOST_LENGTH = 24;

    /**
     * Fixed 64 byte part of the type 3 message: signature, type 3, six
     * (length, max length, offset) buffers and the flags at position 60.
     */
    private final static byte[] MSG3_TEMPLATE = new byte[]{
      'N', 'T', 'L', 'M',
      'S', 'S', 'P', 0,
      3, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0, 0, 0, 0,
      0x01, (byte)0x82, 0, 0,
      };

    private final static int MSG3_POS_LMRESP_LENGTH = 12;
    private final static int MSG3_POS_LMRESP_OFFSET = 16;
    private final static int MSG3_POS_NTRESP_LENGTH = 20;
    private final static int MSG3_POS_NTRESP_OFFSET = 24;
    private final static int MSG3_POS_DOMAIN_LENGTH = 28;
    private final static int MSG3_POS_DOMAIN_OFFSET = 32;
    private final static int MSG3_POS_USER_LENGTH = 36;
    private final static int MSG3_POS_USER_OFFSET = 40;
    private final static int MSG3_POS_HOST_LENGTH = 44;
    private final static int MSG3_POS_HOST_OFFSET = 48;
    /**
     * Position that receives the total message length. In the NTLM
     * AUTHENTICATE layout this appears to be the offset of the (empty)
     * session key buffer - NOTE(review): confirm against MS-NLMP.
     */
    private final static int MSG3_MSG_LENGTH = 56;

    /** 16-bit little endian Unicode without byte order mark. */
    private final static String UCS2LE = "UnicodeLittleUnmarked";

    /**
     * Builds the Base64 encoded NTLM type 1 message.
     * <p>
     * Host name and domain are upper-cased (locale independent) and encoded
     * as ISO-8859-1. The host name follows the template directly, the domain
     * follows the host name; both offsets are absolute message positions.
     *
     * @param hostname local host name without domain part
     * @param ntDomain domain name, may be empty
     * @return the encoded message
     * @exception Exception e.g. if an encoding is not supported
     */
    public static String getMessage1(String hostname, String ntDomain)
      throws Exception {
      byte[] hostBytes = hostname.toUpperCase(Locale.ENGLISH).getBytes("ISO-8859-1");
      byte[] domainBytes = ntDomain.toUpperCase(Locale.ENGLISH).getBytes("ISO-8859-1");

      int hostOffset = MSG1_TEMPLATE.length;
      int domainOffset = hostOffset + hostBytes.length;

      byte[] buffer = new byte[domainOffset + domainBytes.length];
      System.arraycopy(MSG1_TEMPLATE, 0, buffer, 0, MSG1_TEMPLATE.length);

      // hostname length + max length (offset is preset in the template)
      putShort(buffer, MSG1_POS_HOST_LENGTH, hostBytes.length);
      putShort(buffer, MSG1_POS_HOST_LENGTH + 2, hostBytes.length);
      // ntdomain length + max length + absolute offset
      putShort(buffer, MSG1_POS_DOMAIN_LENGTH, domainBytes.length);
      putShort(buffer, MSG1_POS_DOMAIN_LENGTH + 2, domainBytes.length);
      putShort(buffer, MSG1_POS_DOMAIN_OFFSET, domainOffset);

      // hostname + ntdomain string values
      putBytes(buffer, hostOffset, hostBytes);
      putBytes(buffer, domainOffset, domainBytes);

      return Base64.encode(buffer, Integer.MAX_VALUE);
    }

    /**
     * Extracts the 8 byte server nonce (positions 24-31) from a Base64
     * encoded type 2 message.
     *
     * @param base64message2 Base64 encoded type 2 message
     * @return the nonce, or null if the message is too short or not of type 2
     * @exception Exception if decoding fails
     */
    public static byte[] getNonce(String base64message2)
      throws Exception {
      byte[] buffer = Base64.decode(base64message2);
      if (buffer == null) {
        return null;
      }
      if (buffer.length < 40) {
        if (log.beDebug()) {
          log.debugT("getNonce(413)", "unexpected type 2 message length: " + buffer.length + ", buffer=" + printBuffer(buffer));
        }
        return null;
      }
      if (isNTLMMessage2(buffer)) {
        byte[] nonce = new byte[8];
        System.arraycopy(buffer, 24, nonce, 0, nonce.length);
        return nonce;
      }
      return null;
    }

    /**
     * Builds the Base64 encoded NTLM type 3 message answering {@code nonce}.
     * <p>
     * A user name of the form "DOMAIN\\user" or "DOMAIN/user" overrides
     * {@code ntDomain}. Domain, user and host are encoded as UCS-2 little
     * endian and followed by the LM and NT responses.
     *
     * @param hostname local host name
     * @param ntDomain default domain
     * @param user user name and password
     * @param nonce 8 byte server challenge
     * @return the encoded message
     * @exception Exception e.g. NoSuchAlgorithmException if no DES cipher exists
     */
    public static String getMessage3(String hostname, String ntDomain, UserInfo user, byte[] nonce)
      throws Exception {
      Cipher c = null;
      try {
        c = javax.crypto.Cipher.getInstance("DES/ECB/NoPadding");
      }
      catch (NoSuchAlgorithmException ex) {
            //$JL-EXC$
        // Well, give it a try. Newer releases of IAIK do not know the
        // ECB mode in lookups any longer. Sigh.
        // NOTE(review): the helpers below assume 8 output bytes per 8 input
        // bytes; a provider whose plain "DES" default pads would not fit.
        if (log.beInfo()) {
          log.infoT("DES/ECB/NoPadding cipher not found, trying plain DES");
        }
        c = javax.crypto.Cipher.getInstance("DES");
      }
      if (log.beInfo()) {
        log.infoT("DES courtesy of "+ c.getProvider().getName());
      }
      byte[] lm_response = getLMResponse(c, user.getPassword(), nonce);
      byte[] nt_response = getNTResponse(c, user.getPassword(), nonce);

      String userID = user.getUser();
      int index = userID.indexOf('\\');
      if (index < 0) {
        index = userID.indexOf('/');
      }
      if (index > 0 && index < userID.length() - 1) {
        ntDomain = userID.substring(0, index);
        userID = userID.substring(index + 1);
      }

      byte[] domainBytes = ntDomain.getBytes(UCS2LE);
      byte[] userBytes = userID.getBytes(UCS2LE);
      byte[] hostBytes = hostname.getBytes(UCS2LE);

      // size from the actual encoded lengths rather than assuming 2 bytes/char
      int msg_length = MSG3_TEMPLATE.length
         + domainBytes.length + userBytes.length + hostBytes.length
         + lm_response.length + nt_response.length;

      byte[] message = new byte[msg_length];
      System.arraycopy(MSG3_TEMPLATE, 0, message, 0, MSG3_TEMPLATE.length);

      int offset = MSG3_TEMPLATE.length;
      putShort(message, MSG3_MSG_LENGTH, message.length);

      offset = putBytes(message, offset, MSG3_POS_DOMAIN_LENGTH, MSG3_POS_DOMAIN_OFFSET, domainBytes);
      offset = putBytes(message, offset, MSG3_POS_USER_LENGTH, MSG3_POS_USER_OFFSET, userBytes);
      offset = putBytes(message, offset, MSG3_POS_HOST_LENGTH, MSG3_POS_HOST_OFFSET, hostBytes);
      offset = putBytes(message, offset, MSG3_POS_LMRESP_LENGTH, MSG3_POS_LMRESP_OFFSET, lm_response);
      offset = putBytes(message, offset, MSG3_POS_NTRESP_LENGTH, MSG3_POS_NTRESP_OFFSET, nt_response);

      return Base64.encode(message, Integer.MAX_VALUE);
    }

    /** Constant plain text encrypted with the password halves for the LM hash. */
    private final static byte[] LM_MAGIC = {
      (byte)0x4b, (byte)0x47, (byte)0x53, (byte)0x21, (byte)0x40, (byte)0x23, (byte)0x24, (byte)0x25,
      };

    /**
     * Computes the 24 byte LM response: the password is upper-cased and
     * truncated/padded to 14 bytes, each 7 byte half DES-encrypts LM_MAGIC,
     * and the resulting 16 byte hash (zero padded to 21) encrypts the nonce.
     */
    private static byte[] getLMResponse(Cipher c, String password, byte[] nonce)
      throws Exception {
      byte[] lm_password = new byte[14];
      byte[] lm_hashedpw = new byte[21];
      for (int i = 0, n = password.length(); i < n && i < lm_password.length; ++i) {
        lm_password[i] = (byte)Character.toUpperCase(password.charAt(i));
      }

      Key k = getDESKey(lm_password, 0);
      c.init(Cipher.ENCRYPT_MODE, k);
      c.doFinal(LM_MAGIC, 0, LM_MAGIC.length, lm_hashedpw, 0);

      k = getDESKey(lm_password, 7);
      c.init(Cipher.ENCRYPT_MODE, k);
      c.doFinal(LM_MAGIC, 0, LM_MAGIC.length, lm_hashedpw, 8);

      // The hashed password is deliberately not logged: sensitive information.
      return calculateResponse(c, lm_hashedpw, nonce);
    }

    /**
     * Computes the 24 byte NT response: MD4 over the UTF-16LE password,
     * zero padded to 21 bytes, encrypts the nonce.
     */
    private static byte[] getNTResponse(Cipher c, String password, byte[] nonce)
      throws Exception {
      byte[] nt_password = password.getBytes("UTF-16LE");
      byte[] nt_hashedpw = new byte[21];
      MD4 md = MD4.getInstance();
      byte[] hash = md.digest(nt_password);
      System.arraycopy(hash, 0, nt_hashedpw, 0, hash.length);

      // The hashed password is deliberately not logged: sensitive information.
      return calculateResponse(c, nt_hashedpw, nonce);
    }

    /**
     * DES-encrypts the first 8 bytes of {@code plain} with each of the three
     * keys derived from the 21 byte {@code keys} (positions 0, 7, 14) and
     * concatenates the results into 24 bytes.
     */
    private static byte[] calculateResponse(Cipher c, byte[] keys, byte[] plain)
      throws Exception {
      byte[] response = new byte[24];

      Key k = getDESKey(keys, 0);
      c.init(Cipher.ENCRYPT_MODE, k);
      c.doFinal(plain, 0, 8, response, 0);

      k = getDESKey(keys, 7);
      c.init(Cipher.ENCRYPT_MODE, k);
      c.doFinal(plain, 0, 8, response, 8);

      k = getDESKey(keys, 14);
      c.init(Cipher.ENCRYPT_MODE, k);
      c.doFinal(plain, 0, 8, response, 16);

      return response;
    }

    /**
     * Spreads the 56 bits at {@code buffer[offset..offset+6]} over 8 bytes
     * (7 key bits each, lowest bit left 0) and turns them into a DES key.
     */
    private static Key getDESKey(byte[] buffer, int offset)
      throws Exception {
      byte[] key = new byte[8];

      key[0] = buffer[offset];
      key[1] = (byte)(((buffer[offset + 0] << 7) & 0xff) | ((buffer[offset + 1] & 0xff) >> 1));
      key[2] = (byte)(((buffer[offset + 1] << 6) & 0xff) | ((buffer[offset + 2] & 0xff) >> 2));
      key[3] = (byte)(((buffer[offset + 2] << 5) & 0xff) | ((buffer[offset + 3] & 0xff) >> 3));
      key[4] = (byte)(((buffer[offset + 3] << 4) & 0xff) | ((buffer[offset + 4] & 0xff) >> 4));
      key[5] = (byte)(((buffer[offset + 4] << 3) & 0xff) | ((buffer[offset + 5] & 0xff) >> 5));
      key[6] = (byte)(((buffer[offset + 5] << 2) & 0xff) | ((buffer[offset + 6] & 0xff) >> 6));
      key[7] = (byte)((buffer[offset + 6] << 1) & 0xff);

      DESKeySpec ks = new DESKeySpec(key);
      SecretKeyFactory skf = SecretKeyFactory.getInstance("DES");
      return skf.generateSecret(ks);
    }

    /**
     * @return true if {@code buffer} starts with the NTLM signature and its
     *         type byte is 2
     */
    private static boolean isNTLMMessage2(byte[] buffer) {
      if (buffer.length >= MSG_HEADER.length) {
        for (int i = 0; i < MSG_HEADER.length; ++i) {
          if (buffer[i] != MSG_HEADER[i]) {
            if (log.beDebug()) {
              log.debugT("isNTLMMessage2(551)", "type 2 message differs at offset " + i + ": got " + buffer[i] + " instead of " + MSG_HEADER[i]);
            }
            return false;
          }
        }
        if (log.beDebug()) {
          log.debugT("isNTLMMessage2(557)", "type 2 message header OK");
        }

        if (buffer.length > MSG_POS_TYPE) {
          return buffer[MSG_POS_TYPE] == 2;
        }
        return false;
      }
      return false;
    }

    /**
     * Places the low 16 bits of an int as 2 bytes (short) in little endian
     * order.
     *
     * @param buffer target buffer
     * @param offset position of the low byte
     * @param value value to store; bits above 16 are discarded
     */
    private static void putShort(byte[] buffer, int offset, int value) {
      buffer[offset] = (byte)(value & 0xff);
      // shift before truncating, otherwise the high byte is always 0
      buffer[offset + 1] = (byte)((value >>> 8) & 0xff);
    }

    /** Copies {@code value} into {@code buffer} starting at {@code offset}. */
    private static void putBytes(byte[] buffer, int offset, byte[] value) {
      System.arraycopy(value, 0, buffer, offset, value.length);
    }

    /**
     * Copies {@code value} to {@code offset} and records it as a security
     * buffer: length and max length at {@code lenOffset}/{@code lenOffset + 2},
     * the data offset at {@code offsetOffset}.
     *
     * @return the position directly after the copied data
     */
    private static int putBytes(byte[] buffer, int offset, int lenOffset, int offsetOffset, byte[] value) {
      putBytes(buffer, offset, value);
      putShort(buffer, lenOffset, value.length);
      putShort(buffer, lenOffset + 2, value.length);
      putShort(buffer, offsetOffset, offset);
      return offset + value.length;
    }

    /**
     * Formats {@code buffer} as two digit hex values, 16 per line, for
     * debug output.
     */
    public static String printBuffer(byte[] buffer) {
      StringBuffer sb = new StringBuffer();
      sb.append("[\n");
      for (int i = 0, col = 0; i < buffer.length; ++i, ++col) {
        if (col == 16) {
          sb.append("\n");
          col = 0;
        }
        if (col != 0) {
          sb.append(", ");
        }
        String hex = Integer.toString((buffer[i] & 0x0ff), 16);
        if (hex.length() == 1) {
          sb.append('0');
        }
        sb.append(hex);
      }
      sb.append("\n]");
      return sb.toString();
    }

  }
}

