# include "hugin.h"
# include <stdlib.h>
# include <string.h>


/* This program shows how the Hugin C API can be used for learning the
   parameters (conditional probability tables) of a Bayesian belief network
   given a file of case data.  The network and case data must be provided
   via input files.  Then, the conditional probability tables are computed
   from the data using the EM algorithm.  Finally, the updated network is
   saved as a NET file named "EMresult.net".
*/

/* Forward declaration; the definition appears further down in this file.
   main passes the NET file name first and the data file name second. */
void perform_EM_learning (h_string_t, h_string_t);

/* Program entry point.

   Requires exactly two command-line arguments: the NET file name and
   the data file name.  With any other argument count a usage message
   is written to stderr and the program exits with EXIT_FAILURE.
*/
int main (int argc, char *argv[])
{
    if (argc == 3) {
        /* Normal case: hand both file names to the learning routine. */
        perform_EM_learning (argv[1], argv[2]);
        return 0;
    }

    fprintf (stderr, "Usage: %s <NET_file_name> <data_file_name>\n",
             argv[0]);
    exit (EXIT_FAILURE);
}


/* A simple parse error handler: it prints the line number and the
   error message on stderr.  In this file it is passed to
   h_net_parse_domain and h_domain_parse_cases.

   line_no - line number reported by the parser
   message - text describing the error
   data    - user data given when the handler was registered (this file
             always passes NULL); not used here
*/

void error_handler (h_location_t line_no, h_string_t message, void *data)
{
    (void) data;    /* intentionally unused */

    /* h_location_t is an integer type defined by the Hugin API, so it
       must not be printed with "%d" as if it were an int.  Convert it
       explicitly to a type with a known conversion specifier. */
    fprintf (stderr, "Error at line %lu: %s\n",
             (unsigned long) line_no, message);
}

/* Fatal-error helper, called when a Hugin API function reports failure.

   Writes the description of the current Hugin error code to stderr and
   then terminates the program with EXIT_FAILURE; it does not return to
   the caller.
*/

void print_error (void)
{
    fprintf (stderr, "Error: %s\n",
             h_error_description (h_error_code ()));

    exit (EXIT_FAILURE);
}


/* Perform EM learning on the belief network supplied in the NET file
   <net_file_name> (the network must be a non-OOBN network) using the
   data supplied in the file <data_file_name>.

   The result of EM learning is a network with updated conditional
   probability tables (CPTs).  The updated network is saved as a NET
   file named "EMresult.net".

   The format of NET files is described in the Hugin API Reference
   Manual, Chapter 12, and the format of data files is described in
   the same document, Section 11.2.
*/

/* Forward declarations of the two helpers used by perform_EM_learning;
   both are defined further down in this file. */
void randomize_CPTs (h_domain_t);
void initialize_learning_parameters (h_domain_t);

/* Name of the NET file to which the network is saved after learning. */
# define EM_RESULT_FILE "EMresult.net"

/* Implementation notes for perform_EM_learning:

   net_file_name  - name of the NET file; when it does not already end
                    in ".net", that extension is appended before the
                    file is parsed.
   data_file_name - name of the case data file.

   From before compilation until after learning, a log file is attached
   to the domain; its name is the NET file name with ".net" replaced by
   ".log".

   Every failure except one terminates the program with EXIT_FAILURE
   after a message on stderr.  The exception is a failure to close the
   log file: it is only reported, since the learned network can still
   be saved.
*/
void perform_EM_learning (h_string_t net_file_name, h_string_t data_file_name)
{
    size_t stem_length = strlen (net_file_name);
    char *file_name_buffer;
    h_domain_t domain;
    FILE *log_file;

    /* Ignore a trailing ".net" so that the remaining stem can be used
       for both the NET file name and the log file name. */
    if (stem_length >= 4
        && strcmp (net_file_name + (stem_length - 4), ".net") == 0) {
        stem_length -= 4;
    }

    /* Room for the stem, a four-character extension and the
       terminating null character. */
    file_name_buffer = malloc (stem_length + 5);
    if (file_name_buffer == NULL) {
        fprintf (stderr, "Out of memory\n");
        exit (EXIT_FAILURE);
    }

    memcpy (file_name_buffer, net_file_name, stem_length);
    strcpy (file_name_buffer + stem_length, ".net");

    printf ("Parsing NET file \"%s\" ...\n", file_name_buffer);

    domain = h_net_parse_domain (file_name_buffer, error_handler, NULL);
    if (domain == NULL) {
        print_error ();
    }

    printf ("Loading cases from \"%s\" ...\n", data_file_name);

    if (h_domain_parse_cases (domain, data_file_name, error_handler, NULL)
        != 0) {
        print_error ();
    }

    /* Reuse the buffer for the log file name: same stem, ".log". */
    strcpy (file_name_buffer + stem_length, ".log");

    log_file = fopen (file_name_buffer, "w");
    if (log_file == NULL) {
        fprintf (stderr, "Could not open \"%s\"\n", file_name_buffer);
        exit (EXIT_FAILURE);
    }

    /* Checked like every other API call in this function. */
    if (h_domain_set_log_file (domain, log_file) != 0) {
        print_error ();
    }

    printf ("Initializing CPTs ...\n");

    randomize_CPTs (domain);

    printf ("Compiling ...\n");

    if (h_domain_compile (domain) != 0) {
        print_error ();
    }

    printf ("Making copies of junction tree tables for fast inference ...\n");

    if (h_domain_save_to_memory (domain) != 0) {
        print_error ();
    }

    printf ("Initializing experience tables etc. ...\n");

    initialize_learning_parameters (domain);

    printf ("Learning CPTs using the EM algorithm ...\n");

    if (h_domain_learn_tables (domain) != 0) {
        print_error ();
    }

    printf ("   Log likelihood: %g\n", h_domain_get_log_likelihood (domain));

    /* Detach the log file from the domain before closing it, so the
       domain never refers to a closed stream. */
    if (h_domain_set_log_file (domain, NULL) != 0) {
        print_error ();
    }

    /* fclose flushes buffered output; a failure means the log may be
       incomplete.  Report it, but carry on and save the result.  At
       this point file_name_buffer still holds the log file name. */
    if (fclose (log_file) != 0) {
        fprintf (stderr, "Warning: could not close \"%s\"\n",
                 file_name_buffer);
    }

    printf ("Saving output of EM as \"%s\" ...\n", EM_RESULT_FILE);

    if (h_domain_save_as_net (domain, EM_RESULT_FILE) != 0) {
        print_error ();
    }

    printf ("DONE\n");

    h_domain_delete (domain);

    free (file_name_buffer);
}


/* EM learning starts from the conditional probability tables (CPTs)
   that are in place when it begins.  Default CPTs contain uniform
   distributions, but the EM algorithm usually performs better from a
   non-uniform starting point, so every entry is overwritten with a
   random value here.

   Only discrete chance nodes are touched.  Each table entry is set to
   0.25 plus half of a deviate obtained from the domain's uniform
   generator.  If the table of such a node cannot be obtained, the
   program is terminated through print_error.

   Note: The tables written here are not normalized.  A subsequent
   compile operation will take care of that.
*/

void randomize_CPTs (h_domain_t domain)
{
    h_node_t node;

    for (node = h_domain_get_first_node (domain);
         node != NULL;
         node = h_node_get_next (node)) {
        h_table_t cpt;
        h_number_t *values;
        size_t count, i;

        /* Skip everything that is not a discrete chance node. */
        if (h_node_get_category (node) != h_category_chance)
            continue;
        if (h_node_get_kind (node) != h_kind_discrete)
            continue;

        cpt = h_node_get_table (node);
        if (cpt == NULL)
            print_error ();

        values = h_table_get_data (cpt);
        count = h_table_get_size (cpt);

        for (i = 0; i < count; i++)
            values[i] = 0.25 + 0.5 * h_domain_get_uniform_deviate (domain);
    }
}


/* The EM algorithm implemented in Hugin only learns CPTs for nodes having
   "experience" tables.  This procedure requests the experience table of
   every discrete chance node, which makes sure that all such nodes have
   one.  A table that already existed has all its counts reset to zero, so
   that the CPTs are learned from the case data only.  A table that did
   not exist before is left exactly as the library produced it
   (presumably with zero counts -- confirm against the API manual).

   Finally, the parameters that control termination of the EM algorithm
   are set: a log-likelihood tolerance of 0.000001 and a maximum of 1000
   iterations.

   Any Hugin API failure, including a failure to set the termination
   parameters, terminates the program through print_error.

   See Section 11.5 of the Hugin API Reference Manual for further details.
*/

void initialize_learning_parameters (h_domain_t domain)
{
    h_node_t node;

    for (node = h_domain_get_first_node (domain);
         node != NULL;
         node = h_node_get_next (node)) {
        h_boolean_t had_table;
        h_table_t table;

        /* Skip everything that is not a discrete chance node. */
        if (h_node_get_category (node) != h_category_chance
            || h_node_get_kind (node) != h_kind_discrete) {
            continue;
        }

        /* Must be queried before the table is requested, because the
           request below is what brings a missing table into being. */
        had_table = h_node_has_experience_table (node);

        table = h_node_get_experience_table (node);
        if (table == NULL) {
            print_error ();
        }

        if (had_table) {
            h_number_t *counts = h_table_get_data (table);
            size_t size = h_table_get_size (table);
            size_t i;

            for (i = 0; i < size; i++) {
                counts[i] = 0.0;
            }
        }
    }

    /* Do not ignore failures here: otherwise EM would silently run with
       termination settings other than the ones requested. */
    if (h_domain_set_log_likelihood_tolerance (domain, 0.000001) != 0) {
        print_error ();
    }

    if (h_domain_set_max_number_of_em_iterations (domain, 1000) != 0) {
        print_error ();
    }
}
