/**
 * File: levenshtein.h
 * Author: Thomas Leplus
 * Creation date: 2004-07-19
 * Last modification: Time-stamp: <2004-10-05 12:38:40 leplusth>
 */

#ifndef __LEVENSHTEIN_H__
#define __LEVENSHTEIN_H__

/**
 * The default substitution penalty.
 *
 * This macro and the seven that follow supply the default values of
 * the penalty and character arguments of the LAlignment constructor,
 * the LAligner constructor and the free align() function.
 *
 * NOTE(review): identifiers containing a double underscore are
 * reserved for the implementation in C++. The names are left as they
 * are here because code including this header may refer to them.
 */
#define __PNLT_SUB   1

/**
 * The default insertion penalty (default of the `ins` arguments).
 */
#define __PNLT_INS   1

/**
 * The default deletion penalty (default of the `del` arguments).
 */
#define __PNLT_DEL   1

/**
 * The default matching penalty (default of the `match` arguments).
 * Zero means that matching elements add nothing to the distance.
 */
#define __PNLT_MATCH 0

/**
 * The default substitution character for the string representation of
 * alignments (default of the `c_sub` arguments).
 */
#define __CHAR_SUB   'S'

/**
 * The default insertion character for the string representation of
 * alignments (default of the `c_ins` arguments).
 */
#define __CHAR_INS   'I'

/**
 * The default deletion character for the string representation of
 * alignments (default of the `c_del` arguments).
 */
#define __CHAR_DEL   'D'

/**
 * The default match character for the string representation of
 * alignments (default of the `c_match` arguments).
 */
#define __CHAR_MATCH 'M'

#include <algorithm>
#include <cstring>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

/**
 * The usual minimum macro.
 *
 * Expands to a conditional expression, so the argument that is
 * selected is evaluated twice: do not pass expressions with side
 * effects. Both arguments are parenthesized in the expansion.
 */
#define MIN(A,B) ((A) < (B) ? (A) : (B))

/**
 * The usual maximum macro.
 *
 * Same caveat as MIN: the selected argument is evaluated twice.
 */
#define MAX(A,B) ((A) > (B) ? (A) : (B))

// Makes every name of namespace std visible without qualification in
// the rest of this header and, because this is a header, in every file
// that includes it.
using namespace std;

/**
 * A Levenshtein distance computing class.
 *
 * An LAlignment object is the result of an alignment computed by an
 * LAligner object. It contains all the useful information on the
 * best alignment and how it has been computed.
 */
class LAlignment {

 private:

  /**
   * The actual alignment string representation: one operation
   * character per aligned position. Held by value so that the class
   * needs no destructor and copies own independent text.
   */
  std::string _al;

  /**
   * The character that has been used to represent a substitution.
   */
  char _cs;

  /**
   * The character that has been used to represent an insertion.
   */
  char _ci;

  /**
   * The character that has been used to represent a deletion.
   */
  char _cd;

  /**
   * The character that has been used to represent a match.
   */
  char _cm;

  /**
   * The penalty that has been used to weight a substitution.
   */
  int _ps;

  /**
   * The penalty that has been used to weight an insertion.
   */
  int _pi;

  /**
   * The penalty that has been used to weight a deletion.
   */
  int _pd;

  /**
   * The penalty that has been used to weight a match.
   */
  int _pm;

  /**
   * The actual edit distance corresponding to the best alignment.
   */
  int _dist;

  /**
   * The length of the first object aligned.
   */
  size_t _sa;

  /**
   * The length of the second object aligned.
   */
  size_t _sb;

  /**
   * Returns the number of occurrences of a given character in the
   * alignment string representation.
   */
  size_t get_char_count(char c) const {
    return static_cast<size_t>(std::count(_al.begin(), _al.end(), c));
  }

 public:

  /**
   * Constructs a new LAlignment object from all the information
   * computed by the LAligner.
   *
   * The alignment text is copied (a null pointer is treated as an
   * empty alignment). The edit distance is derived here, once, as the
   * sum over the four operations of (occurrences of its character in
   * the text) times (its penalty).
   *
   * Note the order of the character arguments: the deletion character
   * comes before the insertion character, whereas the insertion
   * penalty comes before the deletion penalty.
   *
   * @param alignment NUL-terminated alignment text, or null.
   * @param size_a    length of the first aligned sequence.
   * @param size_b    length of the second aligned sequence.
   * @param sub       substitution penalty.
   * @param ins       insertion penalty.
   * @param del       deletion penalty.
   * @param match     match penalty.
   * @param c_sub     character standing for a substitution.
   * @param c_del     character standing for a deletion.
   * @param c_ins     character standing for an insertion.
   * @param c_match   character standing for a match.
   */
  LAlignment(const char *alignment,
             size_t size_a, size_t size_b,
             int sub = __PNLT_SUB, int ins = __PNLT_INS,
             int del = __PNLT_DEL, int match = __PNLT_MATCH,
             char c_sub = __CHAR_SUB, char c_del = __CHAR_DEL,
             char c_ins = __CHAR_INS, char c_match = __CHAR_MATCH)
    : _al(alignment ? alignment : ""),
      _cs(c_sub), _ci(c_ins), _cd(c_del), _cm(c_match),
      _ps(sub), _pi(ins), _pd(del), _pm(match),
      _dist(0), _sa(size_a), _sb(size_b) {
    _dist = static_cast<int>(get_substitute_count()) * _ps
      + static_cast<int>(get_insert_count()) * _pi
      + static_cast<int>(get_delete_count()) * _pd
      + static_cast<int>(get_match_count()) * _pm;
  }

  /**
   * Returns the string representation of the alignment.
   */
  std::string get_alignment_string() const {
    return _al;
  }

  /**
   * Returns the length of the string representation of the alignment.
   */
  size_t get_alignment_size() const {
    return _al.size();
  }

  /**
   * Returns the character at a given position of the alignment string.
   * No bounds checking is performed; index must not exceed
   * get_alignment_size().
   */
  char get_alignment_at(size_t index) const {
    return _al[index];
  }

  /**
   * Returns the character used to represent substitutions in the
   * alignment string.
   */
  char get_substitute_char() const {
    return _cs;
  }

  /**
   * Returns the number of substitutions in the alignment.
   */
  size_t get_substitute_count() const {
    return get_char_count(_cs);
  }

  /**
   * Returns the penalty used to weight substitutions in the
   * alignment.
   */
  int get_substitute_penalty() const {
    return _ps;
  }

  /**
   * Returns the character used to represent insertions in the
   * alignment string.
   */
  char get_insert_char() const {
    return _ci;
  }

  /**
   * Returns the number of insertions in the alignment.
   */
  size_t get_insert_count() const {
    return get_char_count(_ci);
  }

  /**
   * Returns the penalty used to weight insertions in the
   * alignment.
   */
  int get_insert_penalty() const {
    return _pi;
  }

  /**
   * Returns the character used to represent deletions in the
   * alignment string.
   */
  char get_delete_char() const {
    return _cd;
  }

  /**
   * Returns the number of deletions in the alignment.
   */
  size_t get_delete_count() const {
    return get_char_count(_cd);
  }

  /**
   * Returns the penalty used to weight deletions in the
   * alignment.
   */
  int get_delete_penalty() const {
    return _pd;
  }

  /**
   * Returns the character used to represent matches in the
   * alignment string.
   */
  char get_match_char() const {
    return _cm;
  }

  /**
   * Returns the number of matches in the alignment.
   */
  size_t get_match_count() const {
    return get_char_count(_cm);
  }

  /**
   * Returns the penalty used to weight matches in the
   * alignment.
   */
  int get_match_penalty() const {
    return _pm;
  }

  /**
   * Returns the edit distance of the alignment.
   */
  int get_distance() const {
    return _dist;
  }

  /**
   * Returns the normalized edit distance of the alignment.
   *
   * The distance is divided by a reference cost built from the two
   * sequence lengths: min(size_a, size_b) positions charged at
   * min(sub, del + ins), plus the length difference charged at the
   * deletion penalty when the first sequence is the longer one and at
   * the insertion penalty otherwise. Returns 0 when both sequences are
   * empty. No check is made for a zero reference cost.
   */
  double get_normalized_distance() const {
    if (_sa == 0 && _sb == 0) return 0;
    const int psub = std::min(_ps, _pd + _pi);
    // NOTE(review): this charges surplus elements of the first
    // sequence as deletions; confirm that this matches the convention
    // of whatever produced the alignment text.
    const int pindel = _sa > _sb ? _pd : _pi;
    const int nsub = static_cast<int>(std::min(_sa, _sb));
    const int nindel = static_cast<int>(std::max(_sa, _sb)) - nsub;
    return static_cast<double>(_dist)
      / static_cast<double>(psub * nsub + pindel * nindel);
  }

  /**
   * Prints the string representation of the alignment, followed by
   * the character=penalty pairs in parentheses (substitution,
   * insertion, deletion, match) and the distance in square brackets.
   */
  friend std::ostream& operator<< (std::ostream& output,
                                   const LAlignment& alignment) {
    output << alignment._al << " ("
           << alignment._cs << "=" << alignment._ps << ","
           << alignment._ci << "=" << alignment._pi << ","
           << alignment._cd << "=" << alignment._pd << ","
           << alignment._cm << "=" << alignment._pm << ") ["
           << alignment._dist << "]";
    return output;
  }

};

/**
 * An alignment engine class using the Levenshtein distance.
 *
 * The LAligner class computes the best alignment between two sequence
 * objects of any class (vector, string, array...). All you have to
 * provide is an iterator through the elements of each of those
 * sequences and a binary predicate for equality between any elements
 * of each sequence.
 */
template <class InputIterator1, class InputIterator2, class BinaryPredicate>
class LAligner {

 private:

  /**
   * The operation through which a cell is reached on its cheapest
   * path. Stored instead of the output character so that tracing the
   * path back never depends on the four configured characters being
   * different from one another.
   */
  enum LBack { BACK_MATCH, BACK_SUB, BACK_INS, BACK_DEL };

  /**
   * This is the type of a cell in the workspace matrix.
   */
  struct LCell {
    LBack back;
    int dist;
  };

  /**
   * This is the type of workspace matrix. The first index runs over
   * the width (first sequence), the second over the height (second
   * sequence).
   */
  typedef std::vector< std::vector<LCell> > LMatrix;

  /**
   * This is the workspace matrix.
   */
  LMatrix _matrix;

  /**
   * This is the workspace matrix's current height.
   */
  size_t _h;

  /**
   * This is the workspace matrix's current width.
   */
  size_t _w;

  /**
   * The character that will be used to represent a substitution.
   */
  char _cs;

  /**
   * The character that will be used to represent an insertion.
   */
  char _ci;

  /**
   * The character that will be used to represent a deletion.
   */
  char _cd;

  /**
   * The character that will be used to represent a match.
   */
  char _cm;

  /**
   * The penalty that will be used to weight a substitution.
   */
  int _ps;

  /**
   * The penalty that will be used to weight an insertion.
   */
  int _pi;

  /**
   * The penalty that will be used to weight a deletion.
   */
  int _pd;

  /**
   * The penalty that will be used to weight a match.
   */
  int _pm;

  /**
   * Sizes the workspace matrix to exactly width columns of height
   * cells each and records the new dimensions.
   */
  void matrix_alloc(size_t height, size_t width) {
    _matrix.resize(width);
    for (size_t i = 0; i < width; i++) {
      _matrix[i].resize(height);
    }
    _h = height;
    _w = width;
  }

 public:

  /**
   * Constructs a new alignment engine. You can specify the initial
   * dimensions of the workspace matrix. The matrix will automatically
   * grow according to your needs anyway.
   *
   * Note the order of the character arguments: the deletion character
   * comes before the insertion character, whereas the insertion
   * penalty comes before the deletion penalty.
   */
  LAligner(size_t init_height = 0, size_t init_width = 0,
           int sub = __PNLT_SUB, int ins = __PNLT_INS,
           int del = __PNLT_DEL, int match = __PNLT_MATCH,
           char c_sub = __CHAR_SUB, char c_del = __CHAR_DEL,
           char c_ins = __CHAR_INS, char c_match = __CHAR_MATCH)
    : _matrix(), _h(0), _w(0),
      _cs(c_sub), _ci(c_ins), _cd(c_del), _cm(c_match),
      _ps(sub), _pi(ins), _pd(del), _pm(match) {
    matrix_alloc(init_height, init_width);
  }

  /**
   * Forces the engine to resize the workspace matrix. This can be
   * useful if you aligned unusually long sequences and the size of the
   * workspace matrix may have grown out of proportions.
   */
  void resize(size_t height, size_t width) {
    matrix_alloc(height, width);
  }

  /**
   * Returns the penalty that will be used to weight substitutions.
   */
  int get_substitute_penalty() const {
    return _ps;
  }

  /**
   * Changes the penalty that will be used to weight substitutions.
   */
  void set_substitute_penalty(int sub) {
    _ps = sub;
  }

  /**
   * Returns the character that will be used to represent substitutions.
   */
  char get_substitute_char() const {
    return _cs;
  }

  /**
   * Changes the character that will be used to represent substitutions.
   */
  void set_substitute_char(char c_sub) {
    _cs = c_sub;
  }

  /**
   * Returns the penalty that will be used to weight insertions.
   */
  int get_insert_penalty() const {
    return _pi;
  }

  /**
   * Changes the penalty that will be used to weight insertions.
   */
  void set_insert_penalty(int ins) {
    _pi = ins;
  }

  /**
   * Returns the character that will be used to represent insertions.
   */
  char get_insert_char() const {
    return _ci;
  }

  /**
   * Changes the character that will be used to represent insertions.
   */
  void set_insert_char(char c_ins) {
    _ci = c_ins;
  }

  /**
   * Returns the penalty that will be used to weight deletions.
   */
  int get_delete_penalty() const {
    return _pd;
  }

  /**
   * Changes the penalty that will be used to weight deletions.
   */
  void set_delete_penalty(int del) {
    _pd = del;
  }

  /**
   * Returns the character that will be used to represent deletions.
   */
  char get_delete_char() const {
    return _cd;
  }

  /**
   * Changes the character that will be used to represent deletions.
   */
  void set_delete_char(char c_del) {
    _cd = c_del;
  }

  /**
   * Returns the penalty that will be used to weight matches.
   */
  int get_match_penalty() const {
    return _pm;
  }

  /**
   * Changes the penalty that will be used to weight matches.
   */
  void set_match_penalty(int match) {
    _pm = match;
  }

  /**
   * Returns the character that will be used to represent matches.
   */
  char get_match_char() const {
    return _cm;
  }

  /**
   * Changes the character that will be used to represent matches.
   */
  void set_match_char(char c_match) {
    _cm = c_match;
  }

  /**
   * Computes an alignment of [x_begin, x_end) against [y_begin, y_end).
   *
   * Both ranges are first copied into vectors so they can be indexed.
   * The workspace matrix is resized when either dimension is too small,
   * then filled by dynamic programming. An element of the first
   * sequence left unpaired is scored and recorded as an insertion, an
   * element of the second sequence left unpaired as a deletion. When
   * several operations reach a cell at the same cost, match or
   * substitution is preferred, then insertion, then deletion.
   *
   * @param equal_to called as equal_to(x_element, y_element); a true
   *                 result makes the pair a match, otherwise it is a
   *                 substitution.
   * @return an LAlignment holding the operation string (in sequence
   *         order), both sequence lengths and the penalties and
   *         characters currently configured on this engine.
   */
  LAlignment align(InputIterator1 x_begin, InputIterator1 x_end,
                   InputIterator2 y_begin, InputIterator2 y_end,
                   BinaryPredicate equal_to) {
    std::vector<typename std::iterator_traits<InputIterator1>::value_type>
      x(x_begin, x_end);
    std::vector<typename std::iterator_traits<InputIterator2>::value_type>
      y(y_begin, y_end);
    const size_t w = x.size() + 1;
    const size_t h = y.size() + 1;
    if (h > _h || w > _w)
      resize(h, w);

    // The origin stands for "nothing consumed yet" and therefore costs
    // nothing; the border cells below are measured from that same zero.
    _matrix[0][0].dist = 0;
    _matrix[0][0].back = BACK_MATCH;
    for (size_t i = 1; i < w; i++) {
      _matrix[i][0].dist = static_cast<int>(i) * _pi;
      _matrix[i][0].back = BACK_INS;
    }
    for (size_t j = 1; j < h; j++) {
      _matrix[0][j].dist = static_cast<int>(j) * _pd;
      _matrix[0][j].back = BACK_DEL;
    }

    for (size_t i = 1; i < w; i++) {
      for (size_t j = 1; j < h; j++) {
        LCell &cell = _matrix[i][j];
        const int diagonal = _matrix[i - 1][j - 1].dist;
        if (equal_to(x[i - 1], y[j - 1])) {
          cell.dist = diagonal + _pm;
          cell.back = BACK_MATCH;
        } else {
          cell.dist = diagonal + _ps;
          cell.back = BACK_SUB;
        }
        // Strict comparisons keep the diagonal move on ties, then
        // insertion over deletion.
        const int via_insert = _matrix[i - 1][j].dist + _pi;
        if (cell.dist > via_insert) {
          cell.dist = via_insert;
          cell.back = BACK_INS;
        }
        const int via_delete = _matrix[i][j - 1].dist + _pd;
        if (cell.dist > via_delete) {
          cell.dist = via_delete;
          cell.back = BACK_DEL;
        }
      }
    }

    // Walk back from the last cell to the origin. Operations come out
    // last-to-first, so they are appended here and reversed once below
    // rather than inserted at the front one at a time. Border cells
    // only ever hold BACK_INS (row 0) or BACK_DEL (column 0), so
    // neither index can step below zero.
    std::string reversed;
    reversed.reserve(x.size() + y.size());
    size_t i = w - 1;
    size_t j = h - 1;
    while (i > 0 || j > 0) {
      switch (_matrix[i][j].back) {
        case BACK_MATCH:
          reversed += _cm;
          --i;
          --j;
          break;
        case BACK_SUB:
          reversed += _cs;
          --i;
          --j;
          break;
        case BACK_INS:
          reversed += _ci;
          --i;
          break;
        case BACK_DEL:
          reversed += _cd;
          --j;
          break;
      }
    }
    const std::string alignment(reversed.rbegin(), reversed.rend());

    // The LAlignment constructor takes the characters in the order
    // substitution, deletion, insertion, match.
    return LAlignment(alignment.c_str(), x.size(), y.size(),
                      _ps, _pi, _pd, _pm, _cs, _cd, _ci, _cm);
  }

};

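/**
 * Computes an alignment in a single call: builds a temporary LAligner
 * configured with the given penalties and characters, and returns the
 * alignment of [x_begin, x_end) against [y_begin, y_end).
 */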
template <class InputIterator1, class InputIterator2, class BinaryPredicate>
LAlignment align(InputIterator1 x_begin, InputIterator1 x_end,
                 InputIterator2 y_begin, InputIterator2 y_end,
                 BinaryPredicate equal_to,
                 int sub = __PNLT_SUB, int ins = __PNLT_INS,
                 int del = __PNLT_DEL, int match = __PNLT_MATCH,
                 char c_sub = __CHAR_SUB, char c_del = __CHAR_DEL,
                 char c_ins = __CHAR_INS, char c_match = __CHAR_MATCH) {
  typedef LAligner<InputIterator1, InputIterator2, BinaryPredicate> Engine;

  // Start with an empty workspace (height 0, width 0); the engine
  // sizes it itself when align() is called.
  Engine engine(0, 0,
                sub, ins, del, match,
                c_sub, c_del, c_ins, c_match);

  return engine.align(x_begin, x_end, y_begin, y_end, equal_to);
}

#endif
