[MINOR] Fix lineage tracing of SAMPLE
diff --git a/scripts/builtin/smote.dml b/scripts/builtin/smote.dml
index a223227..dd096ed 100644
--- a/scripts/builtin/smote.dml
+++ b/scripts/builtin/smote.dml
@@ -68,7 +68,11 @@
synthetic_samples = matrix(0, iterLim*ncol(knn_index), ncol(X))
# shuffle the nn indexes
- rand_index = ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+ #rand_index = ifelse(k < iterLim, sample(k, iterLim, TRUE, 42), sample(k, iterLim, 42))
+ if (k < iterLim)
+ rand_index = sample(k, iterLim, TRUE, 42);
+ else
+ rand_index = sample(k, iterLim, 42);
while(iter < iterLim)
{
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index dc7a5f1..af74498 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -409,8 +409,11 @@
}
//replace output variable name with a placeholder
tmpInstStr = InstructionUtils.replaceOperandName(tmpInstStr);
- tmpInstStr = replaceNonLiteral(tmpInstStr, rows, 2, ec);
- tmpInstStr = replaceNonLiteral(tmpInstStr, cols, 3, ec);
+ tmpInstStr = method.name().equalsIgnoreCase("rand") ?
+ replaceNonLiteral(tmpInstStr, rows, 2, ec) :
+ replaceNonLiteral(tmpInstStr, rows, 3, ec);
+ tmpInstStr = method.name().equalsIgnoreCase("rand") ?
+ replaceNonLiteral(tmpInstStr, cols, 3, ec) : tmpInstStr;
break;
}
case SEQ: {