From cffd1657066d65d149c32e0fedaef334eefd06ce Mon Sep 17 00:00:00 2001
From: kar <kas2020@protonmail.com>
Date: Mon, 21 Mar 2022 15:15:16 -0400
Subject: [PATCH] initial work to refactoring QSP

---
 src/Science/QSPCount.cpp | 450 ++++++++++++++++++++++-----------------
 1 file changed, 250 insertions(+), 200 deletions(-)

diff --git a/src/Science/QSPCount.cpp b/src/Science/QSPCount.cpp
index e2f73e9..1971562 100644
--- a/src/Science/QSPCount.cpp
+++ b/src/Science/QSPCount.cpp
@@ -58,175 +58,261 @@ public:
   ~QSPVector() { free(data); }
 };
 
-int index(int row, int col, int rows, int cols) {
-  assert(row >= 0 && row < rows);
-  assert(col >= 0 && col < cols);
-  return (row * cols) + col;
-}
+enum QSPType {
+  QSP_u,
+  QSP_v,
+  QSP_w,
+  QSP_ke,
+  QSP_temp,
+  QSP_rho,
+  QSP_salinity,
+};
 
-void QSPCount(const TArrayn::DTArray &t, const TArrayn::DTArray &u,
-              const TArrayn::DTArray &v, const TArrayn::DTArray &w,
-              const char T1_name, const char S1_name, const int NS,
-              const int NT, double T1_max, double S1_max, double T1_min,
-              double S1_min, const int Nx, const int Ny, const int Nz,
-              string filename, const int plotnum, bool mapped,
-              TArrayn::DTArray *xgrid, TArrayn::DTArray *ygrid,
-              TArrayn::DTArray *zgrid) {
+struct QSPOptions {
+  int NS;
+  int NT;
+  string filename;
+  double S1_max;
+  double S1_min;
+  double T1_max;
+  double T1_min;
+  string T1_name;
+  string S1_name;
+};
 
-  int local_rank;
-  MPI_Comm_rank(MPI_COMM_WORLD, &local_rank);
-  const TArrayn::DTArray *T1_ptr = NULL, *S1_ptr = NULL;
+struct QSPData {
+  TArrayn::DTArray *u;
+  TArrayn::DTArray *v;
+  TArrayn::DTArray *w;
+  TArrayn::DTArray *temp;
+  TArrayn::DTArray *rho;
+  TArrayn::DTArray *salinity;
+  TArrayn::DTArray *xgrid;
+  TArrayn::DTArray *ygrid;
+  TArrayn::DTArray *zgrid;
+  int Nx;
+  int Ny;
+  int Nz;
+  int plotnum;
+  bool mapped;
+};
 
-  switch (T1_name) {
-  case 'u':
-    T1_ptr = &u;
-    break;
-  case 'v':
-    T1_ptr = &v;
-    break;
-  case 'w':
-    T1_ptr = &w;
-    break;
-  case 'k':
-    break;
-  case 't':
-    T1_ptr = &t;
-    break;
-  case 'T':
-    T1_ptr = &t;
-    break;
-  default:
-    return;
+void QSP_write(int local_rank, const QSPVector &local_hist,
+               const QSPOptions &qsp_options, const QSPData &qsp_data) {
+  if (local_rank == 0) {
+    QSPVector glob_hist(qsp_options.NS, qsp_options.NT);
+    MPI_Reduce(local_hist.raw(), glob_hist.raw(), // send and receive buffers
+               qsp_options.NS * qsp_options.NT,   // Count
+               MPI_DOUBLE,                        // datatype
+               MPI_SUM, 0,      // Reduction operator and root process #
+               MPI_COMM_WORLD); // Communicator
+    string filename = qsp_options.filename + "." +
+                      boost::lexical_cast<string>(qsp_data.plotnum) + ".csv";
+    std::fstream outfile;
+    outfile.open(filename.c_str(), std::ios_base::out);
+    if (outfile.is_open()) {
+      outfile << T1_max << ',' << T1_min << ',' << S1_max << ',' << S1_min;
+      for (int i = 4; i < NT; i++) {
+        outfile << ',' << 0;
+      }
+      outfile << std::endl;
+      for (int ii = 0; ii < NS; ii++) {
+        outfile << glob_hist(ii, 0);
+        for (int jj = 1; jj < NT; jj++) {
+          outfile << ',' << glob_hist(ii, jj);
+        }
+        outfile << std::endl;
+      }
+    }
+  } else {
+    MPI_Reduce(local_hist.raw(), NULL, // send and receive buffers
+               NS * NT, MPI_DOUBLE,    // count and datatype
+               MPI_SUM, 0,             // Reduction operator and root process
+               MPI_COMM_WORLD);        // Communicator
+  }
+}
+
+QSPType QSPConvert(const std::string &name) {
+  QSPType converted_type;
+  if (name.compare("u") == 0) {
+    converted_type = QSP_u;
+  } else if (name.compare("v") == 0) {
+    converted_type = QSP_v;
+  } else if (name.compare("w") == 0) {
+    converted_type = QSP_w;
+  } else if (name.compare("ke") == 0) {
+    converted_type = QSP_ke;
+  } else if (name.compare("temp") == 0) {
+    converted_type = QSP_temp;
+  } else if (name.compare("rho") == 0) {
+    converted_type = QSP_rho;
+  } else if (name.compare("salinity") == 0) {
+    converted_type = QSP_salinity;
   }
+  return converted_type;
+}
+
+TArrayn::DTArray *QSPPtr(const QSPData &qsp_data, const QSPType &type) {
+  TArrayn::DTArray *ptr = NULL;
 
-  switch (S1_name) {
-  case 'u':
-    S1_ptr = &u;
+  switch (type) {
+  case QSP_u:
+    ptr = qsp_data.u;
     break;
-  case 'v':
-    S1_ptr = &v;
+  case QSP_v:
+    ptr = qsp_data.v;
     break;
-  case 'w':
-    S1_ptr = &w;
+  case QSP_w:
+    ptr = qsp_data.w;
     break;
-  case 'k':
+  case QSP_ke: // This is an odd case.
     break;
-  case 't':
-    S1_ptr = &t;
+  case QSP_rho:
+    ptr = qsp_data.rho;
     break;
-  case 'T':
-    S1_ptr = &t;
+  case QSP_temp:
+    ptr = qsp_data.temp;
+    break;
+  case QSP_salinity:
+    ptr = qsp_data.salinity;
     break;
-  default:
-    return;
   }
 
-  int i_low = u.lbound(blitz::firstDim);
-  int j_low = u.lbound(blitz::secondDim);
-  int k_low = u.lbound(blitz::thirdDim);
-  int i_high = u.ubound(blitz::firstDim);
-  int j_high = u.ubound(blitz::secondDim);
-  int k_high = u.ubound(blitz::thirdDim);
+  return ptr;
+}
 
-  // This block calculates the global min/max values, in case the user
-  // didn't want to specify them.
+void QSPMaxMin(const QSPType &T1_type, const QSPType &S1_type,
+               TArrayn::DTArray *T1_ptr, TArrayn::DTArray *S1_ptr,
+               QSPOptions &qsp_options, const QSPData &qsp_data) {
   double double_max = std::numeric_limits<double>::max();
-  double double_min = std::numeric_limits<double>::min();
-  if (T1_max == double_max || S1_max == double_max || S1_max == double_min ||
-      S1_min == double_min) {
-
-    // If One of the variables is K.E or rho, we need to hand-roll the max/min
-    // since, in general, the index of max(u) is the same as the index of max(v)
-    if (T1_name == 'k' || T1_name == 't' || S1_name == 'k' || S1_name == 't') {
-      double ke_max = -double_max, ke_min = double_max;
-      double rho_max = -double_max, rho_min = double_max;
 
-      // Main hand-rolled loop
-      for (int i = i_low; i <= i_high; i++) {
-        for (int j = j_low; j <= j_high; j++) {
-          for (int k = k_low; k <= k_high; k++) {
-            double ke_current = 0, tmp = 0;
-            if (Nx > 1) {
-              tmp = u(i, j, k);
-              ke_current += tmp * tmp;
-            }
+  // If One of the variables is K.E or rho, we need to hand-roll the max / min
+  // since, in general, the index of max(u) is the same as the index of max(v)
+  if (T1_type == QSP_ke || T1_type == QSP_rho || S1_type == QSP_ke ||
+      S1_type == QSP_rho) {
+    double ke_max = -double_max, ke_min = double_max;
+    double rho_max = -double_max, rho_min = double_max;
+    // Main hand-rolled loop
+    for (int i = i_low; i <= i_high; i++) {
+      for (int j = j_low; j <= j_high; j++) {
+        for (int k = k_low; k <= k_high; k++) {
+          double tmp;
+          if (T1_type == QSP_ke || S1_type == QSP_ke) {
+            double ke_current = 0;
+            tmp = (*qsp_data.u)(i, j, k);
+            ke_current += tmp * tmp;
             if (Ny > 1) {
-              tmp = v(i, j, k);
-              ke_current += tmp * tmp;
-            }
-            if (Nz > 1) {
-              tmp = w(i, j, k);
+              tmp = (*qsp_data.v)(i, j, k);
               ke_current += tmp * tmp;
             }
+            tmp = (*qsp_data.w)(i, j, k);
+            ke_current += tmp * tmp;
             ke_current = 0.5 * ke_current;
-            double rho_current = eqn_of_state_t(t(i, j, k));
             ke_max = ke_current > ke_max ? ke_current : ke_max;
             ke_min = ke_current < ke_min ? ke_current : ke_min;
+          }
+          if (T1_type == QSP_rho || S1_type == QSP_rho) {
+            double rho_current = eqn_of_state_t(t(i, j, k));
             rho_max = rho_current > rho_max ? rho_current : rho_max;
             rho_min = rho_current < rho_min ? rho_current : rho_min;
           }
         }
       }
+    }
+    double glob_ke_max, glob_ke_min, glob_rho_max, glob_rho_min;
+    MPI_Allreduce(&ke_max, &glob_ke_max, 1, MPI_DOUBLE, MPI_MAX,
+                  MPI_COMM_WORLD);
+    MPI_Allreduce(&ke_min, &glob_ke_min, 1, MPI_DOUBLE, MPI_MIN,
+                  MPI_COMM_WORLD);
+    MPI_Allreduce(&rho_max, &glob_rho_max, 1, MPI_DOUBLE, MPI_MAX,
+                  MPI_COMM_WORLD);
+    MPI_Allreduce(&rho_min, &glob_rho_min, 1, MPI_DOUBLE, MPI_MIN,
+                  MPI_COMM_WORLD);
+    switch (T1_type) {
+    case QSP_ke:
+      qsp_options.T1_max = glob_ke_max;
+      qsp_options.T1_min = glob_ke_min;
+      break;
+    case QSP_rho:
+      qsp_options.T1_max = glob_rho_max;
+      qsp_options.T1_min = glob_rho_min;
+      break;
+    default:
+      qsp_options.T1_max = psmax(max(*T1_ptr));
+      qsp_options.T1_min = psmin(min(*T1_ptr));
+      break;
+    }
+    switch (S1_type) {
+    case QSP_ke:
+      qsp_options.S1_max = glob_ke_max;
+      qsp_options.S1_min = glob_ke_min;
+      break;
+    case QSP_rho:
+      qsp_options.S1_max = glob_rho_max;
+      qsp_options.S1_min = glob_rho_min;
+      break;
+    default:
+      qsp_options.S1_max = psmax(max(*S1_ptr));
+      qsp_options.S1_min = psmin(min(*S1_ptr));
+      break;
+    }
+  } else { // !(cond1 || cond2) == !cond1 && !cond2
+    qsp_options.S1_max = psmax(max(*S1_ptr));
+    qsp_options.S1_min = psmin(min(*S1_ptr));
+    qsp_options.T1_max = psmax(max(*T1_ptr));
+    qsp_options.T1_min = psmin(min(*T1_ptr));
+  }
+}
 
-      double glob_ke_max, glob_ke_min, glob_rho_max, glob_rho_min;
-      MPI_Allreduce(&ke_max, &glob_ke_max, 1, MPI_DOUBLE, MPI_MAX,
-                    MPI_COMM_WORLD);
-      MPI_Allreduce(&ke_min, &glob_ke_min, 1, MPI_DOUBLE, MPI_MIN,
-                    MPI_COMM_WORLD);
-      MPI_Allreduce(&rho_max, &glob_rho_max, 1, MPI_DOUBLE, MPI_MAX,
-                    MPI_COMM_WORLD);
-      MPI_Allreduce(&rho_min, &glob_rho_min, 1, MPI_DOUBLE, MPI_MIN,
-                    MPI_COMM_WORLD);
+void QSPCount(QSPOptions qsp_options, QSPData qsp_data) {
 
-      switch (T1_name) {
-      case 'k':
-        T1_max = glob_ke_max;
-        T1_min = glob_ke_min;
-        break;
-      case 't':
-        T1_max = glob_rho_max;
-        T1_min = glob_rho_min;
-        break;
-      default:
-        T1_max = psmax(max(*T1_ptr));
-        T1_min = psmin(min(*T1_ptr));
-        break;
-      }
+  int local_rank;
+  MPI_Comm_rank(MPI_COMM_WORLD, &local_rank);
 
-      switch (S1_name) {
-      case 'k':
-        S1_max = glob_ke_max;
-        S1_min = glob_ke_min;
-        break;
-      case 't':
-        S1_max = glob_rho_max;
-        S1_min = glob_rho_min;
-        break;
-      default:
-        S1_max = psmax(max(*S1_ptr));
-        S1_min = psmin(min(*S1_ptr));
-        break;
-      }
-    } else { // !(cond1 || cond2) == !cond1 && !cond2
-      S1_max = psmax(max(*S1_ptr));
-      S1_min = psmin(min(*S1_ptr));
-      T1_max = psmax(max(*T1_ptr));
-      T1_min = psmin(min(*T1_ptr));
-    }
+  // Find out what
+  QSPType S1_type = QSPConvert(qsp_options.S1_name);
+  QSPType T1_type = QSPConvert(qsp_options.T1_name);
+  TArrayn::DTArray *S1_ptr = QSPPtr(qsp_data, S1_type);
+  TArrayn::DTArray *T1_ptr = QSPPtr(qsp_data, T1_type);
+
+  if ((!S1_ptr && S1_type != QSP_ke) || (!T1_ptr && T1_type != QSP_ke)) {
+    std::cout << "Not enough data was provided for the requested tracer. "
+                 "Aborting...\n";
+    return;
   }
 
-  double hS = (S1_max - S1_min) / (double)NS;
-  double hT = (T1_max - T1_min) / (double)NT;
+  int i_low, j_low, k_low, i_high, j_high, k_high;
+  TArrayn::DTArray *temp_ptr;
+  if (S1_ptr) { // If S1 is not ke
+    temp_ptr = S1_ptr;
+  } else { // If S1 is ke we know u must exist
+    temp_ptr = qsp_data.u;
+  }
+  int i_low = S1_ptr->lbound(blitz::firstDim);
+  int j_low = S1_ptr->lbound(blitz::secondDim);
+  int k_low = S1_ptr->lbound(blitz::thirdDim);
+  int i_high = S1_ptr->ubound(blitz::firstDim);
+  int j_high = S1_ptr->ubound(blitz::secondDim);
+  int k_high = S1_ptr->ubound(blitz::thirdDim);
+
+  double double_max = std::numeric_limits<double>::max();
+  if (qsp_options.T1_max == double_max || qsp_options.S1_max == double_max ||
+      qsp_options.T1_min == -double_max || qsp_options.S1_min == -double_max) {
+    QSPMaxMin(T1_type, S1_type, T1_ptr, S1_ptr, qsp_options, qsp_data);
+  }
+
+  double hS = (qsp_options.S1_max - qsp_options.S1_min) / (double)NS;
+  double hT = (qsp_options.T1_max - qsp_options.T1_min) / (double)NT;
   double hS_inv = 1 / hS;
   double hT_inv = 1 / hT;
 
-  QSPVector local_hist(NS, NT);
-  QSPVector global_z_max(Nx, Ny);
-  QSPVector global_z_min(Nx, Ny);
+  QSPVector local_hist(qsp_options.NS, qsp_options.NT);
+  QSPVector global_z_max(qsp_data.Nx, qsp_data.Ny);
+  QSPVector global_z_min(qsp_data.Nx, qsp_data.Ny);
   // Find the range of Lz values per 2D-slice
-  if (mapped) {
-    QSPVector local_z_max(Nx, Ny);
-    QSPVector local_z_min(Nx, Ny);
+  if (qsp_data.mapped) {
+    QSPVector local_z_max(qsp_data.Nx, qsp_data.Ny);
+    QSPVector local_z_min(qsp_data.Nx, qsp_data.Ny);
     //  We are slicing as if we are doing zgrid[i, j, :]
     for (int ii = i_low; ii <= i_high; ii++) {
       for (int jj = j_low; jj <= j_high; jj++) {
@@ -236,7 +322,7 @@ void QSPCount(const TArrayn::DTArray &t, const TArrayn::DTArray &u,
         double tmp_z_max = -tmp_z_min;
         double tmp;
         for (int kk = k_low; kk <= k_high; kk++) {
-          tmp = (*zgrid)(ii, jj, kk);
+          tmp = (*qsp_data.zgrid)(ii, jj, kk);
           if (tmp > tmp_z_max) {
             tmp_z_max = tmp;
           } else if (tmp < tmp_z_min) {
@@ -259,64 +345,57 @@ void QSPCount(const TArrayn::DTArray &t, const TArrayn::DTArray &u,
     for (int j = j_low; j <= j_high; j++) {
       for (int k = k_low; k <= k_high; k++) {
 
-        if (T1_name == 'k') {
+        switch (T1_type) {
+        case QSP_ke:
           Tval = 0;
-          if (Nx > 1) {
-            tmp = u(i, j, k);
-            Tval += tmp * tmp;
-          }
+          tmp = (*qsp_data.u)(i, j, k);
+          Tval += tmp * tmp;
+          tmp = (*qsp_data.w)(i, j, k);
+          Tval += tmp * tmp;
           if (Ny > 1) {
-            tmp = v(i, j, k);
-            Tval += tmp * tmp;
-          }
-          if (Nz > 1) {
-            tmp = w(i, j, k);
+            tmp = (*qsp_data.v)(i, j, k);
             Tval += tmp * tmp;
           }
           Tval = 0.5 * Tval;
-        } else if (T1_name == 't') {
+          break;
+        case QSP_rho:
           tmp = (*T1_ptr)(i, j, k);
           Tval = eqn_of_state_t(tmp);
-        } else {
+          break;
+        default:
           Tval = (*T1_ptr)(i, j, k);
-        }
-        int idxT = floor((Tval - T1_min) * hT_inv);
-        if (idxT < 0) {
-          idxT = 0;
-        } else if (idxT >= NT) {
-          idxT = 0;
+          break;
         }
 
-        if (S1_name == 'k') {
+        switch (S1_type) {
+        case QSP_ke:
           Sval = 0;
-          if (Nx > 1) {
-            tmp = u(i, j, k);
-            Sval += tmp * tmp;
-          }
+          tmp = (*qsp_data.u)(i, j, k);
+          Sval += tmp * tmp;
+          tmp = (*qsp_data.w)(i, j, k);
+          Sval += tmp * tmp;
           if (Ny > 1) {
-            tmp = v(i, j, k);
-            Sval += tmp * tmp;
-          }
-          if (Nz > 1) {
-            tmp = w(i, j, k);
+            tmp = (*qsp_data.v)(i, j, k);
             Sval += tmp * tmp;
           }
           Sval = 0.5 * Sval;
-        } else if (S1_name == 't') {
+          break;
+        case QSP_rho:
           tmp = (*S1_ptr)(i, j, k);
           Sval = eqn_of_state_t(tmp);
-        } else {
+          break;
+        default:
           Sval = (*S1_ptr)(i, j, k);
+          break;
         }
+
         int idxS = floor((Sval - S1_min) * hS_inv);
-        if (idxS < 0) {
-          idxS = 0;
-        } else if (idxS >= NS) {
-          idxS = 0;
-        }
+        int idxT = floor((Tval - T1_min) * hT_inv);
+        idxS = std::max(std::min(idxS, qsp_options.NS), 0);
+        idxT = std::max(std::min(idxT, qsp_options.NT), 0);
 
         double volume_weight;
-        if (mapped) {
+        if (qsp_data.mapped) {
           // Calculate the Lz range
           double Lzmax_now = global_z_max(i, j);
           double Lzmin_now = global_z_min(i, j);
@@ -346,40 +425,11 @@ void QSPCount(const TArrayn::DTArray &t, const TArrayn::DTArray &u,
           volume_weight = 1.0;
         }
 
-        // local_hist[index(idxS, idxT, NS, NT)] += volume_weight;
         local_hist(idxS, idxT) += volume_weight;
       }
     }
   }
 
   MPI_Barrier(MPI_COMM_WORLD); // Wait for everyone to finish
-  if (local_rank == 0) {
-    QSPVector glob_hist(NS, NT);
-    MPI_Reduce(local_hist.raw(), glob_hist.raw(), // send and receive buffers
-               NS * NT, MPI_DOUBLE,               // count and datatype
-               MPI_SUM, 0,      // Reduction operator and root process #
-               MPI_COMM_WORLD); // Communicator
-    filename = filename + "." + boost::lexical_cast<string>(plotnum) + ".csv";
-    std::fstream outfile;
-    outfile.open(filename.c_str(), std::ios_base::out);
-    if (outfile.is_open()) {
-      outfile << T1_max << ',' << T1_min << ',' << S1_max << ',' << S1_min;
-      for (int i = 4; i < NT; i++) {
-        outfile << ',' << 0;
-      }
-      outfile << std::endl;
-      for (int ii = 0; ii < NS; ii++) {
-        outfile << glob_hist[index(ii, 0, NS, NT)];
-        for (int jj = 1; jj < NT; jj++) {
-          outfile << ',' << glob_hist[index(ii, jj, NS, NT)];
-        }
-        outfile << std::endl;
-      }
-    }
-  } else {
-    MPI_Reduce(local_hist.raw(), NULL, // send and receive buffers
-               NS * NT, MPI_DOUBLE,    // count and datatype
-               MPI_SUM, 0,             // Reduction operator and root process #
-               MPI_COMM_WORLD);        // Communicator
-  }
+  QSP_write(local_rank, local_hist, qsp_options, qsp_data);
 }
-- 
GitLab