# include "hugin"

# include <vector>
# include <string>
# include <cstdio>
# include <iostream>
# include <exception>

using namespace HAPI;
using namespace std;

/**
 * Small driver that runs an EM learning session on a HUGIN network.
 *
 * All of the work happens in the constructor: it loads "<fileName>.net",
 * attaches "<fileName>.log" as the domain log file, compiles the domain,
 * sets up and prints the learning parameters, enters and prints the case
 * data, runs table learning, prints the resulting log likelihood and node
 * beliefs, and finally saves the domain as "q.net".
 *
 * The class holds no state; the private members are helpers that operate
 * on the Domain handed to them.
 */
class EM {
public:
  /**
   * Runs the complete learning session.
   *
   * Declared explicit so that a string is never silently converted into
   * an EM object (which would run the whole session as a side effect).
   *
   * @param fileName  base name of the network; ".net" and ".log" are
   *                  appended to form the input and log file names.
   */
  explicit EM (const string &fileName);

private:
  /** Fills every discrete chance node's experience table with ones and
      sets the log-likelihood tolerance and the EM iteration limit. */
  void specifyLearningParameters (Domain *d);

  /** Prints each discrete chance node's experience and fading table
      contents (when present) followed by the domain-wide EM settings. */
  void printLearningParameters (Domain *d);

  /** Replaces the domain's case data with a single case. */
  void loadCases (Domain *d);

  /** Prints the count and the per-node state of every stored case. */
  void printCases (Domain *d);

  /** Prints the belief of every state of every discrete chance node. */
  void printNodeMarginals (Domain *d);
};


int main (int argc, char *argv[])
{
  if (argc != 2) {
    cerr << "Usage: " << argv[0] << " <net_file>\n";
    return -1;
  }

  new EM (string (argv[1]));

  return 0;
}


/**
 * Runs one complete EM learning session.
 *
 * Steps, in order: load "<fileName>.net", attach "<fileName>.log" as the
 * domain's log file, compile, set and print the learning parameters,
 * enter and print the case data, learn the tables, print the log
 * likelihood and the node beliefs, and save the result as "q.net".
 *
 * If the log file cannot be opened a warning is written to stderr and the
 * session continues; the domain is then handed a NULL log file, exactly
 * as it was before this check existed.
 *
 * @param fileName  base name of the network (without extension).
 */
EM::EM (const string &fileName)
{
  // Detaches the log file from the domain and closes it when the
  // constructor is left, whether normally or through an exception.
  // NOTE(review): relies on setLogFile (NULL) turning logging off --
  // confirm against the HUGIN API reference.
  struct LogFileGuard {
    Domain &domain;
    FILE *file;

    LogFileGuard (Domain &dom, FILE *f) : domain (dom), file (f) {}

    ~LogFileGuard ()
    {
      if (file == NULL)
        return;

      // A destructor must not let an exception escape; closing the file
      // matters more than reporting a failure to detach it.
      try {
        domain.setLogFile (NULL);
      }
      catch (...) {
      }

      fclose (file);
    }
  };

  string netFileName = fileName + ".net";
  Domain d (netFileName, NULL);

  string logFileName = fileName + ".log";
  FILE *logFile = fopen (logFileName.c_str (), "w");
  if (logFile == NULL)
    cerr << "Warning: cannot open log file '" << logFileName
         << "'; continuing without a log\n";

  // Declared after 'd' so that it is destroyed first, i.e. the file is
  // detached and closed while the domain is still alive.
  LogFileGuard logGuard (d, logFile);

  d.setLogFile (logFile);

  d.compile ();

  specifyLearningParameters (&d);
  printLearningParameters (&d);

  loadCases (&d);
  printCases (&d);

  d.learnTables ();

  cout << "Log likelihood: " << d.getLogLikelihood () << endl;

  printNodeMarginals (&d);

  d.saveAsNet ("q.net");
}


/**
 * Prepares the domain for EM learning.
 *
 * For every node that is a DiscreteChanceNode, the node's experience
 * table is overwritten so that each entry holds 1.  Afterwards the
 * log-likelihood tolerance is set to 0.000001 and the maximum number of
 * EM iterations to 1000.
 *
 * @param d  domain to configure.
 */
void EM::specifyLearningParameters (Domain *d)
{
  const NodeList nodes = d->getNodes ();
  const NodeList::const_iterator stop = nodes.end ();

  // Scratch buffer, reused for every node.
  NumberList ones;

  for (NodeList::const_iterator it = nodes.begin (); it != stop; ++it) {
    DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (*it);

    // Nodes of any other kind are left untouched.
    if (chance == 0)
      continue;

    Table *experience = chance->getExperienceTable ();

    // One entry per table cell, each set to 1.
    ones.clear ();
    ones.insert (ones.end (), experience->getSize (), 1);

    experience->setData (ones);
  }

  d->setLogLikelihoodTolerance (0.000001);
  d->setMaxNumberOfEMIterations (1000);
}


void EM::printLearningParameters (Domain *d)
{
  const NodeList nl = d->getNodes ();

  for (NodeList::const_iterator nlIter = nl.begin (), nlEnd = nl.end ();
       nlIter != nlEnd; ++nlIter) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (*nlIter);

    if (dcNode != 0) {
      cout << dcNode->getLabel () << " (" << dcNode->getName () << "):\n";

      cout << "   ";
      if (dcNode->hasExperienceTable ()) {
	Table *table = dcNode->getExperienceTable ();
	NumberList data = table->getData ();
	size_t tblSize = table->getSize ();

	for (size_t i = 0; i < tblSize; i++)
	  cout << data[i] << " ";

	cout << endl;
      }
      else
	cout << "No experience table\n";

      cout << "   ";
      if (dcNode->hasFadingTable ()) {
	Table *table = dcNode->getFadingTable ();
	NumberList data = table->getData ();
	size_t tblSize = table->getSize ();

	for (size_t i = 0; i < tblSize; i++)
	  cout << data[i] << " ";

	cout << endl;
      }
      else
	cout << "No fading table\n";
    }
  }

  cout << "Log likelihood tolerance: " << d->getLogLikelihoodTolerance ()
       << endl;
  cout << "Max EM iterations: " << d->getMaxNumberOfEMIterations () << endl;
}


/**
 * Replaces the domain's case data with one hand-built case.
 *
 * The number of cases is first set to 0, then a new case is created and
 * given a case count of 2.5.  Every DiscreteChanceNode gets state 0 in
 * that case; afterwards the value of the second node in the domain's
 * node list is removed again (if that node exists and is a
 * DiscreteChanceNode), so the case has one unobserved variable.
 *
 * @param d  domain that receives the case.
 */
void EM::loadCases (Domain *d)
{
  d->setNumberOfCases (0);

  size_t iCase = d->newCase ();
  cout << "Case index: " << iCase << endl;

  d->setCaseCount (iCase, 2.5);

  const NodeList nl = d->getNodes ();

  for (NodeList::const_iterator nlIter = nl.begin (), nlEnd = nl.end ();
       nlIter != nlEnd; ++nlIter) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (*nlIter);

    if (dcNode != 0)
      dcNode->setCaseState (iCase, 0);
  }

  // Guard the index: a network with fewer than two nodes has no nl[1],
  // and reading it would be an out-of-bounds access.
  if (nl.size () > 1) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (nl[1]);
    if (dcNode != 0)
      dcNode->unsetCase (iCase);
  }
}


/**
 * Prints the case data stored in the domain.
 *
 * The first line gives the number of cases.  Each case is then written
 * as "case <index> <count> " followed by one " (<name>,<state>) " group
 * per DiscreteChanceNode, where <state> is "N/A" when the node has no
 * value in that case.
 *
 * NOTE(review): no line break is written between cases; all of them end
 * up on one line and a single newline is emitted after the last one.
 *
 * @param d  domain whose cases are printed.
 */
void EM::printCases (Domain *d)
{
  const NodeList nodes = d->getNodes ();
  const NodeList::const_iterator stop = nodes.end ();

  const size_t caseCount = d->getNumberOfCases ();

  cout << "Number of cases: " << caseCount << endl;

  for (size_t c = 0; c < caseCount; ++c) {
    cout << "case " << c << " " << d->getCaseCount (c) << " ";

    for (NodeList::const_iterator it = nodes.begin (); it != stop; ++it) {
      DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (*it);

      // Other node kinds are skipped.
      if (chance == 0)
        continue;

      cout << " (" << chance->getName () << ",";

      if (!chance->caseIsSet (c)) {
        cout << "N/A) ";
        continue;
      }

      cout << chance->getCaseState (c) << ") ";
    }
  }

  cout << endl;
}


/**
 * Prints the beliefs of all discrete chance nodes.
 *
 * Each DiscreteChanceNode produces a header line "label (name)" followed
 * by one " - <state label>: <belief>" line per state.
 *
 * @param d  domain whose nodes are printed.
 */
void EM::printNodeMarginals (Domain *d)
{
  const NodeList nodes = d->getNodes ();
  const NodeList::const_iterator stop = nodes.end ();

  for (NodeList::const_iterator it = nodes.begin (); it != stop; ++it) {
    DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (*it);

    // Only discrete chance nodes are reported.
    if (chance == 0)
      continue;

    const size_t stateCount = chance->getNumberOfStates ();

    cout << chance->getLabel () << " (" << chance->getName () << ")\n";

    for (size_t s = 0; s < stateCount; ++s) {
      cout << " - " << chance->getStateLabel (s) << ": ";
      cout << chance->getBelief (s) << endl;
    }
  }
}
