# include "hugin"

# include <cmath>
# include <cstdio>
# include <cstdlib>
# include <exception>
# include <iostream>
# include <string>

using namespace HAPI;
using namespace std;

void printBeliefsAndUtilities (Domain*);
bool containsUtilities (const NodeList&);


/* This function parses the given NET file, compiles the network, and
   prints the prior beliefs and expected utilities of all nodes.  If a
   case file is given, the function loads the file, propagates the
   evidence, and prints the updated results.

   netName is the NET file name without the ".net" suffix; the
   triangulation/compilation log is written to "<netName>.log".
   caseFileName may be NULL, in which case no evidence is entered.

   If the network contains utility nodes (i.e., it is a LIMID), we
   assume that we should compute policies for all decisions (rather
   than use the ones specified in the NET file).  Likewise, we update
   the policies when new evidence arrives.

   If the log file cannot be opened, the program exits with
   EXIT_FAILURE.  If triangulation or compilation throws, the log file
   is detached from the domain and closed before the exception is
   propagated to the caller.
*/
void loadAndPropagate (const char *netName, const char *caseFileName)
{
    DefaultParseListener pl;
    string netFileName = netName;
    Domain domain (netFileName + ".net", &pl);
    string logFileName = netFileName + ".log";
    FILE *logFile = fopen (logFileName.c_str(), "w");

    if (logFile == NULL)
    {
	cerr << "Could not open \"" << logFileName << "\"\n";
	exit (EXIT_FAILURE);
    }

    domain.setLogFile (logFile);

    try
    {
	domain.triangulate (H_TM_BEST_GREEDY);
	domain.compile();
    }
    catch (...)
    {
	/* Without this, a failing triangulation/compilation would leak
	   the FILE and leave the domain referring to it.  */
	domain.setLogFile (NULL);
	fclose (logFile);
	throw;
    }

    domain.setLogFile (NULL);
    fclose (logFile);

    bool hasUtilities = containsUtilities (domain.getNodes());

    if (!hasUtilities)
	cout << "Prior beliefs:\n";
    else
    {
	domain.updatePolicies();
	cout << "Overall expected utility: " << domain.getExpectedUtility()
	     << "\n\nPrior beliefs (and expected utilities):\n";
    }

    printBeliefsAndUtilities (&domain);

    if (caseFileName != NULL)
    {
	domain.parseCase (caseFileName, &pl);
	cout << "\n\nPropagating the evidence specified in \""
	     << caseFileName << "\"\n";

	domain.propagate (H_EQUILIBRIUM_SUM, H_MODE_NORMAL);

	cout << "\nP(evidence) = " << domain.getNormalizationConstant() << endl;

	if (!hasUtilities)
	    cout << "\nUpdated beliefs:\n";
	else
	{
	    /* Re-optimize the decisions given the newly entered evidence.  */
	    domain.updatePolicies();
	    cout << "\nOverall expected utility: "
		 << domain.getExpectedUtility()
		 << "\n\nUpdated beliefs (and expected utilities):\n";
	}

	printBeliefsAndUtilities (&domain);
    }
}


/** Print the beliefs and expected utilities of all nodes in the domain.

    Each node is printed as "[T] label (name)", where T is C (chance),
    D (decision), U (utility), or F (anything else, i.e. function
    nodes), followed by:
      - utility nodes: the expected utility;
      - discrete nodes: the belief of each state and, when the domain
        contains utility nodes, the expected utility of each state
        (requested for chance and decision nodes only);
      - real-valued function nodes: the value, or "N/A" when it cannot
        be obtained;
      - remaining (continuous chance) nodes: the mean and standard
        deviation.  */

void printBeliefsAndUtilities (Domain *domain)
{
    NodeList nodes = domain->getNodes();
    bool hasUtilities = containsUtilities (nodes);

    for (NodeList::const_iterator it = nodes.begin(); it != nodes.end(); ++it)
    {
	Node *node = *it;

	Category category = node->getCategory();
	char type = (category == H_CATEGORY_CHANCE ? 'C'
		     : category == H_CATEGORY_DECISION ? 'D'
		     : category == H_CATEGORY_UTILITY ? 'U' : 'F');

	cout << "\n[" << type << "] " << node->getLabel()
	     << " (" << node->getName() << ")\n";

	if (category == H_CATEGORY_UTILITY)
	{
	    UtilityNode *uNode = dynamic_cast<UtilityNode*> (node);
	    cout << "  - Expected utility: " << uNode->getExpectedUtility()
		 << endl;
	}
	else if (node->getKind() == H_KIND_DISCRETE)
	{
	    /* Tested before the function category: a function-category
	       node of discrete kind is handled here as a DiscreteNode
	       instead of going through dynamic_cast<FunctionNode*>.
	       NOTE(review): assumes discrete function nodes are
	       DiscreteNode objects, not FunctionNode objects, in the
	       HUGIN class hierarchy.  */
	    DiscreteNode *dNode = dynamic_cast<DiscreteNode*> (node);

	    /* Same node set as before: only chance and decision nodes
	       were asked for per-state expected utilities.  */
	    bool showUtilities = hasUtilities && category != H_CATEGORY_FUNCTION;

	    for (size_t i = 0, n = dNode->getNumberOfStates(); i < n; i++)
	    {
		cout << "  - " << dNode->getStateLabel (i)
		     << " " << dNode->getBelief (i);
		if (showUtilities)
		    cout << " (" << dNode->getExpectedUtility (i) << ")";
		cout << endl;
	    }
	}
	else if (category == H_CATEGORY_FUNCTION)
	{
	    FunctionNode *fNode = dynamic_cast<FunctionNode*> (node);

	    if (fNode == NULL)
		cout << "  - Value: N/A\n";
	    else
	    {
		try
		{
		    /* Fetch the value before printing anything, so a
		       throwing getValue() doesn't leave a partial line.  */
		    double value = fNode->getValue ();
		    cout << "  - Value: " << value << endl;
		}
		catch (const ExceptionHugin&)
		{
		    cout << "  - Value: N/A\n";
		}
	    }
	}
	else
	{
	    ContinuousChanceNode *ccNode
		= dynamic_cast<ContinuousChanceNode*> (node);

	    cout << "  - Mean : " << ccNode->getMean() << endl;
	    cout << "  - SD   : " << sqrt (ccNode->getVariance()) << endl;
	}
    }
}


/** Are there utility nodes in the list? */

bool containsUtilities (const NodeList& list)
{
    for (size_t i = 0, n = list.size(); i < n; i++)
	if (list[i]->getCategory() == H_CATEGORY_UTILITY)
	    return true;

    return false;
}


/*
 * Load a Hugin NET file, compile the network, and print the results.
 * If a case file is specified, load it, propagate the evidence, and
 * print the results.
 */

int main (int argc, const char *argv[])
{
    if (argc < 2 || argc > 3)
    {
	cerr << "Usage: " << argv[0] << " <NET_file_name> [<case_file_name>]\n";
	exit (EXIT_FAILURE);
    }

    loadAndPropagate (argv[1], argv[2]);

    return 0;
}
