/* ----------------------------------------------------------------------
 
    This program tests the use of MPI Collective Communications
    subroutines.  The structure of the program is:
 
       --  Master node queries for random number seed
       --  The seed is sent to all nodes (Lab project)
       --  Each node calculates one random number based on the seed
           and the rank 
       --  The node with highest rank calculates the mean value
           of the random numbers (Lab project)
       --  4 more random numbers are generated by each node
       --  The maximum value and the standard deviation of all
           generated random numbers are calculated, and the
           results are made available to all nodes (Lab project)
 
    Also provided is a service routine GetStats(rnum, N, outd), where
 
          rnum:  array of random numbers (INPUT)
          N:     number of elements in rnum (INPUT)
          outd:  array of size 2 containing the maximum value and
                 standard deviation (OUTPUT)
 
  ---------------------------------------------------------------------- */
#include <stdio.h>
#include <stddef.h>
#include <stdlib.h>
#include <math.h>
#include "mpi.h"

#define NUM_RANDS 5

void GetStats(float* rnum, int n, float* rval);

/*
 * Program entry point.
 *
 * Rank 0 reads an unsigned seed from ./c.seed and broadcasts it.  Every
 * rank seeds rand() with (seed + rank) and draws NUM_RANDS values in
 * [0, 100].  The highest rank receives the sum of each rank's first value
 * and writes the mean to c.data.  The maximum and standard deviation of
 * all values are then computed twice -- once with Gather + Bcast and once
 * with Allgather -- and the highest rank appends both results to c.data.
 *
 * Returns 0 on success.  Any setup failure (file cannot be opened, seed
 * cannot be parsed, out of memory) aborts the whole MPI job so that the
 * remaining ranks are not left blocked inside a collective call.
 */
int main(int argc, char** argv)
{
	FILE          *fp = NULL, *GetSeed;
	int            ii, numtasks, taskid, nvals;
	unsigned int   seed = 0;
	float          meanv, randnum[NUM_RANDS], rval[2];
	float          sum = 0.0f;   /* stays 0 on ranks that are not the reduce root */
	float         *rnum;

	MPI_Init(&argc, &argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &taskid);
	MPI_Comm_size(MPI_COMM_WORLD, &numtasks);

	/* Only the highest rank ever writes results, so only it opens (and
	   truncates) the output file. */
	if (taskid == (numtasks - 1))
	{
		fp = fopen("c.data", "w");
		if (fp == NULL)
		{
			fprintf(stderr, "Cannot open c.data\n");
			MPI_Abort(MPI_COMM_WORLD, 1);
			return 1;
		}
	}

	/*  Get random number seed */
	if (taskid == 0)
	{
		GetSeed = fopen("./c.seed", "r");
		if (GetSeed == NULL)
		{
			fprintf(stderr, "Cannot open c.seed\n");
			MPI_Abort(MPI_COMM_WORLD, 1);
			return 1;
		}

		if (fscanf(GetSeed, "%u", &seed) != 1)
		{
			fprintf(stderr, "Cannot read seed from c.seed\n");
			fclose(GetSeed);
			MPI_Abort(MPI_COMM_WORLD, 1);
			return 1;
		}
		printf("seed is %u\n", seed);
		fclose(GetSeed);
	}

	/*
	================================================
	Project:  Send seed from task 0 to all nodes.
	================================================
	*/
	MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);

	printf("\n Task %d after broadcast; seed = %u", taskid, seed);

	/* Get a random number; offsetting by the rank gives each task its
	   own sequence. */
	srand(seed + taskid);
	randnum[0] = 100.* (float) rand() / (float) RAND_MAX;

	/*
	==============================================================
	Project:  Have the node with highest rank calculate the
	          mean value of the random numbers and store result
	          in the variable "meanv".
	==============================================================
	*/
	MPI_Reduce(randnum, &sum, 1, MPI_FLOAT, MPI_SUM, numtasks-1,
	           MPI_COMM_WORLD);
	meanv = sum / (float) numtasks;

	/*
	==============================================================
	Only task numtasks-1 (the reduce root) receives the sum, so only
	it computes the correct meanv.  For demonstration purposes every
	task still computes and prints it; on the other tasks "sum" keeps
	its initial value of 0.  Only task numtasks-1 writes the result
	to the file.
	==============================================================
	*/
	printf("\n Task %d after mean value; ", taskid);
	printf("random[0] = %8.3f sum = %8.3f mean = %8.3f",
	       randnum[0], sum, meanv);

	/* Highest task writes out mean value */
	if (taskid == (numtasks - 1))
		fprintf(fp, " For seed = %u    mean value = %10.3f\n", seed, meanv);

	/*  Generate 4 more random numbers */
	for (ii = 1; ii < NUM_RANDS; ii++)
		randnum[ii] = 100. * (float) rand() / (float) RAND_MAX;

	printf("\n Task %d  random numbers : randnum(1:%d) =", taskid, NUM_RANDS);
	for (ii = 0; ii < NUM_RANDS; ii++)
		printf(" %8.3f", randnum[ii]);

	/* The receive buffer must hold NUM_RANDS values from every task, so
	   size it from the communicator instead of using a fixed length. */
	nvals = NUM_RANDS * numtasks;
	rnum = malloc((size_t) nvals * sizeof *rnum);
	if (rnum == NULL)
	{
		fprintf(stderr, "Cannot allocate receive buffer\n");
		MPI_Abort(MPI_COMM_WORLD, 1);
		return 1;
	}

	/*  initialize the receiving buffer */
	for (ii = 0; ii < nvals; ii++)
		rnum[ii] = 0.0;

	/*
	==================================================================
	Project:  Calculate the maximum value and standard deviation of
	          all random numbers generated, and make results known
	          to all nodes.
	Method 1:  Use GATHER followed by BCAST
	Method 2:  Use ALLGATHER
	==================================================================
	*/


	/*   ------  Method 1   ----------- */
	MPI_Gather(randnum, NUM_RANDS, MPI_FLOAT, rnum, NUM_RANDS, MPI_FLOAT, 0, MPI_COMM_WORLD);

	if (taskid == 0)
		GetStats(rnum, nvals, rval);

	MPI_Bcast(rval, 2, MPI_FLOAT, 0, MPI_COMM_WORLD);

	/*  Only taskid = 0 will be guaranteed to have rnum to print  */
	printf("\n Task %d after Method 1, rnum(%d:%d) =",
	       taskid, taskid * NUM_RANDS + 1, taskid * NUM_RANDS + NUM_RANDS);
	for (ii = taskid * NUM_RANDS; ii < taskid * NUM_RANDS + NUM_RANDS; ii++)
		printf(" %8.3f", rnum[ii]);

	if (taskid == (numtasks-1))
		fprintf(fp, " (Max, S.D.) = %10.3f%10.3f\n", rval[0], rval[1]);

	/*   ------  Method 2   -----------  */
	MPI_Allgather(randnum, NUM_RANDS, MPI_FLOAT, rnum, NUM_RANDS, MPI_FLOAT, MPI_COMM_WORLD);
	GetStats(rnum, nvals, rval);

	/*  all tasks will have their rnum by the time the print is called */
	printf("\n Task %d after Method 2, rnum(%d:%d) =",
	       taskid, taskid * NUM_RANDS + 1, taskid * NUM_RANDS + NUM_RANDS);
	for (ii = taskid * NUM_RANDS; ii < taskid * NUM_RANDS + NUM_RANDS; ii++)
		printf(" %8.3f", rnum[ii]);
	printf("\n");

	if (taskid == (numtasks-1))
		fprintf(fp, " (Max, S.D.) = %10.3f%10.3f\n", rval[0], rval[1]);

	free(rnum);

	/* fclose flushes the buffered results; report a failed flush. */
	if (fp != NULL && fclose(fp) != 0)
		fprintf(stderr, "Error closing c.data\n");

	MPI_Finalize();
	return 0;
}

 
/* ----------------------------------------------------------------------
   Service routine GetStats( rnum, N, data), where
 
       rnum:  array of random numbers (INPUT)
       N:     number of elements in rnum (INPUT)
       outd:  array of size 2 containing the maximum value and
              standard deviation (OUTPUT)
 
  ---------------------------------------------------------------------- */

void GetStats(float* rnum, int N, float* outd)
{
 	float  meanv, sdev, sum;
	int    ii;

	sum   = 0.;
	*outd = 0.;

	for(ii = 0; ii < N; ii++)
	{
		sum += *(rnum + ii);
		if(*(rnum + ii) > *outd) 
			*outd = *(rnum + ii);
	}

	meanv = sum / (float) N;
	sdev = 0.;
	for(ii = 0; ii < N; ii++)
		sdev += (*(rnum + ii) - meanv) * (*(rnum + ii) - meanv);

	*(outd + 1) = sqrt(sdev / (float) N);
}