# include "hugin"

# include <vector>
# include <string>
# include <cstdio>
# include <iostream>
# include <exception>

using namespace HAPI;
using namespace std;

/*
 * Small driver demonstrating Hugin parameter adaptation.
 *
 * All work is done by the constructor: it loads "<fileName>.net",
 * configures experience/fading tables, enters a case, adapts the
 * conditional probability tables and writes the result to "q.net".
 * The private helpers operate on the domain owned by the constructor.
 */
class Adapt {
public:
  /* fileName is the base name (without extension) of the network. */
  explicit Adapt (const string &fileName);

private:
  /* Set every experience count and fading factor of discrete chance nodes to 1. */
  void specifyLearningParameters (Domain *d);
  /* Print the experience and fading tables of discrete chance nodes. */
  void printLearningParameters (Domain *d);
  /* Select state 0 on every discrete chance node, then retract the second node. */
  void enterCase (Domain *d);
  /* Print, per discrete chance node, whether evidence is entered. */
  void printCase (Domain *d);
  /* Print the belief of every state of every discrete chance node. */
  void printNodeMarginals (Domain *d);
};


int main (int argc, char *argv[])
{
  new Adapt (string (argv[1]));

  return 0;
}


/*
 * Load "<fileName>.net", log Hugin activity to "<fileName>.log", set the
 * learning parameters, enter a case, adapt the tables, and save the
 * adapted network as "q.net".
 *
 * The log file is always detached from the domain and closed before this
 * constructor returns or rethrows, so the domain never holds a pointer to
 * a closed stream.
 */
Adapt::Adapt (const string &fileName)
{
  string netFileName = fileName + ".net";
  Domain d (netFileName, NULL);

  string logFileName = fileName + ".log";
  FILE *logFile = fopen (logFileName.c_str (), "w");
  if (logFile == NULL)
    cerr << "Warning: could not open log file '" << logFileName
         << "'; continuing without logging" << endl;
  // A NULL log file disables logging in Hugin, so this is safe either way.
  d.setLogFile (logFile);

  try {
    d.compile ();

    specifyLearningParameters (&d);
    printLearningParameters (&d);

    enterCase (&d);

    printCase (&d);

    // adapt() uses the propagated evidence, so propagate first.
    d.propagate ();

    d.adapt ();

    d.initialize ();

    printNodeMarginals (&d);

    d.saveAsNet ("q.net");
  }
  catch (...) {
    // Release the log stream before the error propagates.
    d.setLogFile (NULL);
    if (logFile != NULL)
      fclose (logFile);
    throw;
  }

  d.setLogFile (NULL);
  if (logFile != NULL)
    fclose (logFile);
}


/*
 * Fill the learning tables of every discrete chance node with ones.
 *
 * Pass 0 sets all experience tables and pass 1 sets all fading tables.
 * The get*Table accessors are used unconditionally, so each node ends up
 * with both tables populated.
 */
void Adapt::specifyLearningParameters (Domain *d)
{
  NodeList nodes = d->getNodes ();
  NumberList ones;

  for (int pass = 0; pass < 2; ++pass) {
    for (size_t k = 0; k < nodes.size (); ++k) {
      DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (nodes[k]);
      if (chance == 0)
        continue;

      Table *target = (pass == 0) ? chance->getExperienceTable ()
                                  : chance->getFadingTable ();

      ones.assign (target->getSize (), 1);
      target->setData (ones);
    }
  }
}


void Adapt::printLearningParameters (Domain *d)
{
  NodeList nl = d->getNodes ();

  for (NodeList::const_iterator nlIter = nl.begin (), nlEnd = nl.end ();
       nlIter != nlEnd; ++nlIter) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (*nlIter);

    if (dcNode != 0) {
      cout << dcNode->getLabel () << " (" << dcNode->getName ()
	   << "): " << endl;

      cout << "   ";
      if (dcNode->hasExperienceTable ()) {
	Table *table = dcNode->getExperienceTable ();
	NumberList data = table->getData ();
	size_t tblSize = table->getSize ();

	for (size_t i = 0; i < tblSize; i++)
	  cout << data[i] << " ";

	cout << endl;
      }
      else
	cout << "No experience table" << endl;

      cout << "   ";
      if (dcNode->hasFadingTable ()) {
	Table *table = dcNode->getFadingTable ();
	NumberList data = table->getData ();
	size_t tblSize = table->getSize ();

	for (size_t i = 0; i < tblSize; i++)
	  cout << data[i] << " ";

	cout << endl;
      }
      else
	cout << "No fading table" << endl;
    }
  }
}


/*
 * Enter a case: select state 0 on every discrete chance node, then
 * retract the finding on the node at index 1 of the domain's node list
 * so that it remains unobserved.
 *
 * NOTE(review): which node sits at index 1 depends on the order returned
 * by Domain::getNodes(); confirm this is the intended node.
 */
void Adapt::enterCase (Domain *d)
{
  NodeList nl = d->getNodes ();

  for (NodeList::const_iterator nlIter = nl.begin (), nlEnd = nl.end ();
       nlIter != nlEnd; ++nlIter) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (*nlIter);

    if (dcNode != 0)
      dcNode->selectState (0);
  }

  // Guard the fixed index: domains with fewer than two nodes would
  // otherwise read past the end of the list.
  if (nl.size () > 1) {
    DiscreteChanceNode *dcNode = dynamic_cast<DiscreteChanceNode*> (nl[1]);

    if (dcNode != 0)
      dcNode->retractFindings ();
  }
}


/*
 * Print on one line, for each discrete chance node, its name and whether
 * evidence is currently entered on it.
 */
void Adapt::printCase (Domain *d)
{
  NodeList nodes = d->getNodes ();

  for (size_t k = 0; k < nodes.size (); ++k) {
    DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (nodes[k]);
    if (chance == 0)
      continue;

    string name = chance->getName ();
    const char *status = chance->isEvidenceEntered ()
                         ? " evidence is entered) "
                         : " evidence is not entered) ";

    cout << " (" << name << "," << status;
  }

  cout << endl;
}


/*
 * Print each discrete chance node as "label (name)", followed by one line
 * per state of the form " - <state label>: <belief>".
 */
void Adapt::printNodeMarginals (Domain *d)
{
  NodeList nodes = d->getNodes ();

  for (size_t k = 0; k < nodes.size (); ++k) {
    DiscreteChanceNode *chance = dynamic_cast<DiscreteChanceNode*> (nodes[k]);
    if (chance == 0)
      continue;

    size_t stateCount = chance->getNumberOfStates ();

    string heading = chance->getLabel () + " (" + chance->getName () + ")";
    cout << heading << endl;

    for (size_t s = 0; s != stateCount; ++s) {
      cout << " - " << chance->getStateLabel (s)
           << ": " << chance->getBelief (s) << endl;
    }
  }
}
