Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
Expand Down Expand Up @@ -724,21 +726,41 @@ private static void setThreadID(long tid, FederatedRequest[]... frsets) {
Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
}

/**
* Sort the entries of the federation map based on their federated ranges
*/
private void sortFederatedRanges() {
int dim = (this.getType() == FType.COL) ? 1 : 0;

this._fedMap.sort(new Comparator<Pair<FederatedRange, FederatedData>>() {
@Override
public int compare(Pair<FederatedRange, FederatedData> o1, Pair<FederatedRange, FederatedData> o2) {
return o1.getLeft().getBeginDimsInt()[dim] - o2.getLeft().getBeginDimsInt()[dim];
}
});
}

public void reverseFedMap() {
// TODO perf
// TODO: add a check if the map is sorted based on indexes before reversing.
// TODO: add a setup such that on construction the federated map is already sorted.
FederatedRange[] fedRanges = getFederatedRanges();

for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) {
FederatedData data1 = getFederatedData(fedRanges[i]);
FederatedData data2 = getFederatedData(fedRanges[fedRanges.length-1-i]);

removeFederatedData(fedRanges[i]);
removeFederatedData(fedRanges[fedRanges.length-1-i]);

_fedMap.add(Pair.of(fedRanges[i], data2));
_fedMap.add(Pair.of(fedRanges[fedRanges.length-1-i], data1));
if(this.getType() != FType.ROW)
throw new DMLRuntimeException("Reversing is only supported for row partitioned federation maps yet.");

this.sortFederatedRanges();

Collections.reverse(this._fedMap);

int dim = (getType() == FType.COL) ? 1 : 0;
int currentDimPos = 0;
Iterator<Pair<FederatedRange, FederatedData>> fmIter = this._fedMap.iterator();
while(fmIter.hasNext()) {
Pair<FederatedRange, FederatedData> elem = fmIter.next();
long dimSize = elem.getLeft().getSize(dim);
long[] beginDims = elem.getLeft().getBeginDims();
long[] endDims = elem.getLeft().getEndDims();
beginDims[dim] = currentDimPos;
currentDimPos += dimSize;
endDims[dim] = currentDimPos;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedRevTest extends AutomatedTestBase {
// private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());

private final static String TEST_NAME = "FederatedRevTest";

private final static String TEST_DIR = "functions/federated/";
Expand Down Expand Up @@ -86,11 +84,25 @@ public void federatedCompilationRevSP() {
runRevTest(Types.ExecMode.SPARK, true);
}

@Test
public void testRevDifferentRangesCP() {
runRevTest(Types.ExecMode.SINGLE_NODE, false, true);
}

@Test
public void testRevDifferentRangesSP() {
runRevTest(Types.ExecMode.SPARK, false, true);
}

private void runRevTest(ExecMode execMode) {
runRevTest(execMode, false);
}

private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
runRevTest(execMode, activateFedCompilation, false);
}

private void runRevTest(ExecMode execMode, boolean activateFedCompilation, boolean differentPartitionSizes) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;

Expand All @@ -108,20 +120,52 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
c = cols;
}

double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
int r_X1 = r; int r_X2 = r; int r_X3 = r; int r_X4 = r;
int rend_X1 = r; int rend_X2 = r; int rend_X3 = r; int rend_X4 = r;
int c_X1 = c; int c_X2 = c; int c_X3 = c; int c_X4 = c;
int cend_X1 = c_X1; int cend_X2 = c_X1+c_X2; int cend_X3 = cend_X2+c_X3; int cend_X4 = cend_X3+c_X4;
if(rowPartitioned) {
if(differentPartitionSizes) {
r_X1 = r+1;
r_X2 = r-2;
r_X3 = r+1;
r_X4 = r-0;
}
else {
r_X1 = r;
r_X2 = r;
r_X3 = r;
r_X4 = r;
}
rend_X1 = r_X1; rend_X2 = r_X1+r_X2; rend_X3 = rend_X2+r_X3; rend_X4 = rend_X3+r_X4;
c_X1 = c; c_X2 = c; c_X3 = c; c_X4 = c;
cend_X1 = c; cend_X2 = c; cend_X3 = c; cend_X4 = c;
}
else if(differentPartitionSizes) {
c_X1 = c+1;
c_X2 = c-2;
c_X3 = c+1;
c_X4 = c-0;
cend_X1 = c_X1; cend_X2 = c_X1+c_X2; cend_X3 = cend_X2+c_X3; cend_X4 = cend_X3+c_X4;
}

double[][] X1 = getRandomMatrix(r_X1, c_X1, 1, 5, 1, 3);
double[][] X2 = getRandomMatrix(r_X2, c_X2, 1, 5, 1, 7);
double[][] X3 = getRandomMatrix(r_X3, c_X3, 1, 5, 1, 8);
double[][] X4 = getRandomMatrix(r_X4, c_X4, 1, 5, 1, 9);

for(int k : new int[] {1, 2, 3}) {
Arrays.fill(X3[k], 0);
}

MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
writeInputMatrixWithMTD("X3", X3, false, mc);
writeInputMatrixWithMTD("X4", X4, false, mc);
writeInputMatrixWithMTD("X1", X1, false,
new MatrixCharacteristics(r_X1, c_X1, blocksize, r_X1 * c_X1));
writeInputMatrixWithMTD("X2", X2, false,
new MatrixCharacteristics(r_X2, c_X2, blocksize, r_X2 * c_X2));
writeInputMatrixWithMTD("X3", X3, false,
new MatrixCharacteristics(r_X3, c_X3, blocksize, r_X3 * c_X3));
writeInputMatrixWithMTD("X4", X4, false,
new MatrixCharacteristics(r_X4, c_X4, blocksize, r_X4 * c_X4));

// empty script name because we don't execute any script, just start the worker
fullDMLScriptName = "";
Expand All @@ -134,7 +178,6 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
Process t4 = startLocalFedWorker(port4);


try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting federated worker");
Expand All @@ -147,7 +190,8 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {

// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};

runTest(null);
Expand All @@ -158,8 +202,14 @@ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
"rows=" + rows, "cols=" + cols,
"rend_X1=" + rend_X1, "cend_X1=" + cend_X1,
"rend_X2=" + rend_X2, "cend_X2=" + cend_X2,
"rend_X3=" + rend_X3, "cend_X3=" + cend_X3,
"rend_X4=" + rend_X4, "cend_X4=" + cend_X4,
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};

runTest(null);

Expand Down
8 changes: 4 additions & 4 deletions src/test/scripts/functions/federated/FederatedRevTest.dml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
#-------------------------------------------------------------
if ($rP) {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
ranges=list(list(0, 0), list($rend_X1, $cols), list($rend_X1, 0), list($rend_X2, $cols),
list($rend_X2, 0), list($rend_X3, $cols), list($rend_X3, 0), list($rend_X4, $cols)));
} else {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
ranges=list(list(0, 0), list($rows, $cend_X1), list(0,$cend_X1), list($rows, $cend_X2),
list(0,$cend_X2), list($rows, $cend_X3), list(0, $cend_X3), list($rows, $cend_X4)));
}

s = rev(A);
Expand Down
Loading