diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index b0cccce1719..68dff4785b3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -21,6 +21,8 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Map.Entry; @@ -724,21 +726,41 @@ private static void setThreadID(long tid, FederatedRequest[]... frsets) { Arrays.stream(frset).forEach(fr -> fr.setTID(tid)); } + /** + * Sort the entries of the federation map based on their federated ranges + */ + private void sortFederatedRanges() { + int dim = (this.getType() == FType.COL) ? 1 : 0; + + this._fedMap.sort(new Comparator>() { + @Override + public int compare(Pair o1, Pair o2) { + return o1.getLeft().getBeginDimsInt()[dim] - o2.getLeft().getBeginDimsInt()[dim]; + } + }); + } + public void reverseFedMap() { // TODO perf - // TODO: add a check if the map is sorted based on indexes before reversing. // TODO: add a setup such that on construction the federated map is already sorted. - FederatedRange[] fedRanges = getFederatedRanges(); - - for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) { - FederatedData data1 = getFederatedData(fedRanges[i]); - FederatedData data2 = getFederatedData(fedRanges[fedRanges.length-1-i]); - - removeFederatedData(fedRanges[i]); - removeFederatedData(fedRanges[fedRanges.length-1-i]); - - _fedMap.add(Pair.of(fedRanges[i], data2)); - _fedMap.add(Pair.of(fedRanges[fedRanges.length-1-i], data1)); + if(this.getType() != FType.ROW) + throw new DMLRuntimeException("Reversing is only supported for row partitioned federation maps yet."); + + this.sortFederatedRanges(); + + Collections.reverse(this._fedMap); + + int dim = (getType() == FType.COL) ? 1 : 0; + int currentDimPos = 0; + Iterator> fmIter = this._fedMap.iterator(); + while(fmIter.hasNext()) { + Pair elem = fmIter.next(); + long dimSize = elem.getLeft().getSize(dim); + long[] beginDims = elem.getLeft().getBeginDims(); + long[] endDims = elem.getLeft().getEndDims(); + beginDims[dim] = currentDimPos; + currentDimPos += dimSize; + endDims[dim] = currentDimPos; } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java index e4b7ed5e24a..7fe88228d06 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java @@ -39,8 +39,6 @@ @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedRevTest extends AutomatedTestBase { - // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); - private final static String TEST_NAME = "FederatedRevTest"; private final static String TEST_DIR = "functions/federated/"; @@ -86,11 +84,25 @@ public void federatedCompilationRevSP() { runRevTest(Types.ExecMode.SPARK, true); } + @Test + public void testRevDifferentRangesCP() { + runRevTest(Types.ExecMode.SINGLE_NODE, false, true); + } + + @Test + public void testRevDifferentRangesSP() { + runRevTest(Types.ExecMode.SPARK, false, true); + } + private void runRevTest(ExecMode execMode) { runRevTest(execMode, false); } private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { + runRevTest(execMode, activateFedCompilation, false); + } + + private void runRevTest(ExecMode execMode, boolean activateFedCompilation, boolean differentPartitionSizes) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; ExecMode platformOld = rtplatform; @@ -108,20 +120,52 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { c = cols; } - double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); - double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); - double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); - double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + int r_X1 = r; int r_X2 = r; int r_X3 = r; int r_X4 = r; + int rend_X1 = r; int rend_X2 = r; int rend_X3 = r; int rend_X4 = r; + int c_X1 = c; int c_X2 = c; int c_X3 = c; int c_X4 = c; + int cend_X1 = c_X1; int cend_X2 = c_X1+c_X2; int cend_X3 = cend_X2+c_X3; int cend_X4 = cend_X3+c_X4; + if(rowPartitioned) { + if(differentPartitionSizes) { + r_X1 = r+1; + r_X2 = r-2; + r_X3 = r+1; + r_X4 = r-0; + } + else { + r_X1 = r; + r_X2 = r; + r_X3 = r; + r_X4 = r; + } + rend_X1 = r_X1; rend_X2 = r_X1+r_X2; rend_X3 = rend_X2+r_X3; rend_X4 = rend_X3+r_X4; + c_X1 = c; c_X2 = c; c_X3 = c; c_X4 = c; + cend_X1 = c; cend_X2 = c; cend_X3 = c; cend_X4 = c; + } + else if(differentPartitionSizes) { + c_X1 = c+1; + c_X2 = c-2; + c_X3 = c+1; + c_X4 = c-0; + cend_X1 = c_X1; cend_X2 = c_X1+c_X2; cend_X3 = cend_X2+c_X3; cend_X4 = cend_X3+c_X4; + } + + double[][] X1 = getRandomMatrix(r_X1, c_X1, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r_X2, c_X2, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r_X3, c_X3, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r_X4, c_X4, 1, 5, 1, 9); for(int k : new int[] {1, 2, 3}) { Arrays.fill(X3[k], 0); } - MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); - writeInputMatrixWithMTD("X1", X1, false, mc); - writeInputMatrixWithMTD("X2", X2, false, mc); - writeInputMatrixWithMTD("X3", X3, false, mc); - writeInputMatrixWithMTD("X4", X4, false, mc); + writeInputMatrixWithMTD("X1", X1, false, + new MatrixCharacteristics(r_X1, c_X1, blocksize, r_X1 * c_X1)); + writeInputMatrixWithMTD("X2", X2, false, + new MatrixCharacteristics(r_X2, c_X2, blocksize, r_X2 * c_X2)); + writeInputMatrixWithMTD("X3", X3, false, + new MatrixCharacteristics(r_X3, c_X3, blocksize, r_X3 * c_X3)); + writeInputMatrixWithMTD("X4", X4, false, + new MatrixCharacteristics(r_X4, c_X4, blocksize, r_X4 * c_X4)); // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; @@ -134,7 +178,6 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); Process t4 = startLocalFedWorker(port4); - try { if(!isAlive(t1, t2, t3, t4)) throw new RuntimeException("Failed starting federated worker"); @@ -147,7 +190,8 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { // Run reference dml script with normal matrix fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; - programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + programArgs = new String[] {"-stats", "100", "-args", + input("X1"), input("X2"), input("X3"), input("X4"), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; runTest(null); @@ -158,8 +202,14 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) { "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), - "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, - "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "rows=" + rows, "cols=" + cols, + "rend_X1=" + rend_X1, "cend_X1=" + cend_X1, + "rend_X2=" + rend_X2, "cend_X2=" + cend_X2, + "rend_X3=" + rend_X3, "cend_X3=" + cend_X3, + "rend_X4=" + rend_X4, "cend_X4=" + cend_X4, + "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), + "out_S=" + output("S")}; runTest(null); diff --git a/src/test/scripts/functions/federated/FederatedRevTest.dml b/src/test/scripts/functions/federated/FederatedRevTest.dml index d43edd1728b..128b511d5fc 100644 --- a/src/test/scripts/functions/federated/FederatedRevTest.dml +++ b/src/test/scripts/functions/federated/FederatedRevTest.dml @@ -20,12 +20,12 @@ #------------------------------------------------------------- if ($rP) { A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), - ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), - list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + ranges=list(list(0, 0), list($rend_X1, $cols), list($rend_X1, 0), list($rend_X2, $cols), + list($rend_X2, 0), list($rend_X3, $cols), list($rend_X3, 0), list($rend_X4, $cols))); } else { A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), - ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), - list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); + ranges=list(list(0, 0), list($rows, $cend_X1), list(0,$cend_X1), list($rows, $cend_X2), + list(0,$cend_X2), list($rows, $cend_X3), list(0, $cend_X3), list($rows, $cend_X4))); } s = rev(A);