! ----------------------------------------------------------------------
!
!   This program tests the use of MPI Collective Communications
!   subroutines.  The structure of the program is:
!
!      --  Master node queries for random number seed
!      --  The seed is sent to all nodes (Lab project)
!      --  Each node calculates one random number based on the seed
!          and the rank
!      --  The node with highest rank calculates the mean value
!          of the random numbers (Lab project)
!      --  4 more random numbers are generated by each node
!      --  The maximum value and the standard deviation of all
!          generated random numbers are calculated, and the
!          results are made available to all nodes (Lab project)
!
!   Also provided is a service routine GetStats(rnum,N,outd), where
!
!         rnum:  array of random numbers (INPUT)
!         N:     number of elements in rnum (INPUT)
!         outd:  array of size 2 containing the maximum value and
!                standard deviation (OUTPUT)
!
! ----------------------------------------------------------------------
      program CollCom
!
!     Demonstrates MPI collective operations: BCAST of a seed read by
!     task 0, REDUCE of one random number per task onto the highest
!     rank, then GATHER+BCAST and ALLGATHER of NUM_RANDS numbers per
!     task to compute their maximum and standard deviation.  Results
!     are written to 'f.data' by the highest-ranked task.
!
      implicit none
      include 'mpif.h'

      integer NUM_RANDS
      parameter (NUM_RANDS=5)

      integer numtasks, taskid, nrands, ierr, ios, i, ii
      real*4  randnum(NUM_RANDS), sum, meanv, rval(2)
      real*4  seed
!     RAND is a vendor extension; give it the REAL type it received
!     implicitly before IMPLICIT NONE was added.
      real*4  rand
!     Receive buffer for the gathers (NUM_RANDS values from every
!     task).  Sized at run time: a fixed rnum(100) overflowed as soon
!     as more than 20 tasks took part.
      real*4, allocatable :: rnum(:)

      call MPI_INIT(ierr)
      call MPI_COMM_RANK(MPI_COMM_WORLD, taskid, ierr)
      call MPI_COMM_SIZE(MPI_COMM_WORLD, numtasks, ierr)
      nrands = NUM_RANDS

      allocate( rnum(numtasks*NUM_RANDS) )
      rnum = 0.0
!
!     Task 0 reads the random number seed.  If the file is missing or
!     unreadable, abort the whole job; otherwise every other task
!     would block forever in the broadcast below.
      if( taskid .eq. 0 ) then
         open(unit=20, file='f.seed', status='OLD', iostat=ios)
         if( ios .eq. 0 ) then
            read(20,*,iostat=ios) seed
            close(unit=20)
         end if
         if( ios .ne. 0 ) then
            print*, 'Cannot read random number seed from f.seed'
            call MPI_ABORT(MPI_COMM_WORLD, 1, ierr)
         end if
         print*, 'seed = ',seed
      end if
!
! ================================================
!   Project:  Send seed from task 0 to all nodes.
! ================================================
!
!     seed is real*4, so the matching MPI datatype is MPI_REAL.
!     (MPI_DOUBLE_PRECISION would store 8 bytes into a 4-byte
!     variable on every receiving task.)
      call MPI_BCAST( seed, 1, MPI_REAL, 0,
     &                MPI_COMM_WORLD, ierr)

!
      write( *, '(a, i3, a, f12.5)') ' Task', taskid,
     &      ' after broadcast; seed = ', seed
!     Give each task its own stream (seed + rank).  SRAND expects an
!     INTEGER seed, so the real seed is converted explicitly.
      call srand( int(seed) + taskid )
!     Each node generates a random number
      randnum(1) = 100*rand()

!
! ==============================================================
!   Project:  Have the node with highest rank calculate the
!             mean value of the random numbers and store result
!             in the variable "meanv".
! ==============================================================
!
!     The reduced sum is only defined on the root (numtasks-1).  It
!     is zeroed first so the other tasks print 0, not garbage.
      sum = 0.0
      call MPI_REDUCE( randnum(1), sum, 1, MPI_REAL, MPI_SUM,
     &                 numtasks-1, MPI_COMM_WORLD, ierr )
! ================================
! Only one task will be able to compute the correct meanv,
! but here we will have all tasks compute it and write to
! standard out for demonstration purposes.  Then, only the
! task numtasks-1 will write the result to the file.
! ================================

      meanv = sum/REAL(numtasks)
!
      write(*, '(a, i3, a, f8.3, a, f8.3, a, f8.3)')
     &      ' Task', taskid, ' after mean value; random(1) =',
     &      randnum(1), ' sum =', sum, ' mean =', meanv

!     Highest task writes out mean value
      if( taskid .eq. (numtasks-1) ) then
         open( unit=10, file='f.data', status='UNKNOWN' )
         write(10, '(a, f12.5, a, f10.3)') ' For seed = ', seed,
     &         '   mean value = ', meanv
      end if

!     Each node generates NUM_RANDS-1 more random numbers
      do ii=2, NUM_RANDS
         randnum(ii) = 100.0*rand()
      end do
!
      write( *, '(a, i3, a ,i1, a,5f8.3)') ' Task', taskid,
     &     ' ===>  randnum(1:',NUM_RANDS,') =', (randnum(i), i=1,
     &     NUM_RANDS)

! ==================================================================
!   Project:  Calculate the maximum value and standard deviation of
!             all random numbers generated, and make results known
!             to all nodes.
!   Method 1:  Use GATHER followed by BCAST
!   Method 2:  Use ALLGATHER
! ==================================================================
!
!   ------  Method 1   -----------
!     Only task 0 receives the gathered data; on the other tasks rnum
!     keeps whatever it held before (zeros at this point).
      call MPI_GATHER( randnum, NUM_RANDS, MPI_REAL, rnum, NUM_RANDS,
     &  MPI_REAL, 0,   MPI_COMM_WORLD, ierr )
      if( taskid .eq. 0 ) then
         call GetStats( rnum, numtasks*NUM_RANDS, rval )
      end if
      call MPI_BCAST( rval, 2, MPI_REAL, 0, MPI_COMM_WORLD, ierr )
!
      write( *, '(a, i3, a, i2,a, i2, a, 5f8.3)') ' Task', taskid,
     &      ' after Method 1, rnum(',taskid*nrands+1,':',
     &      taskid*nrands+NUM_RANDS,') =', (rnum(i),
     &      i=taskid*nrands+1, taskid*nrands+nrands)
      if( taskid.eq.(numtasks-1) )
     &    write( 10,'(a, 2f10.3)') ' (Max, S.D.) = ', rval

!   ------  Method 2   -----------
!     Every task receives all values and computes the stats itself.
      call MPI_ALLGATHER( randnum, NUM_RANDS, MPI_REAL, rnum,
     &     NUM_RANDS, MPI_REAL,  MPI_COMM_WORLD, ierr )
      call GetStats( rnum, numtasks*NUM_RANDS, rval )

!
      write( *, '(a, i3, a, i2, a, i2, a, 5f8.3)') ' Task', taskid,
     &      ' after Method 2, rnum(',taskid*NUM_RANDS+1,':',
     &      taskid*NUM_RANDS+NUM_RANDS,') =',
     &      (rnum(i), i=taskid*NUM_RANDS+1,
     &                taskid*NUM_RANDS+NUM_RANDS)
      if( taskid .eq. (numtasks-1)) then
         write( 10,'(a, 2f10.3)') ' (Max, S.D.) = ', rval
         close( 10 )
      end if

      deallocate( rnum )
      call MPI_FINALIZE( ierr )
      end

! ----------------------------------------------------------------------
!   The service routine GetStats(rnum,N,outd), where
!
!         rnum:  array of random numbers (INPUT)
!         N:     number of elements in rnum (INPUT)
!         outd:  array of size 2 containing the maximum value and
!                the population standard deviation (OUTPUT)
!
!   When N < 1 there is no data and both outputs are set to 0.
! ----------------------------------------------------------------------
      subroutine GetStats( rnum, N, outd )
      implicit none
      integer N
      real*4 rnum(N), outd(2)
      real*4 meanv

      outd(1) = 0.
      outd(2) = 0.
!     Guard against 0/0 (NaN) for an empty data set.
      if( N .lt. 1 ) return

!     maxval starts from the data itself, so all-negative input still
!     reports its true maximum (a running max seeded with 0 did not).
      outd(1) = maxval( rnum )
      meanv = sum( rnum ) / REAL(N)
!     Population standard deviation: divide by N, not N-1.
      outd(2) = sqrt( sum( (rnum - meanv)**2 ) / REAL(N) )

      return
      end