[MINOR] Minor fixes in the smote builtin
Replace the incremental rbind/cbind appends with indexed writes into preallocated matrices.
Update the Rand call to use a fixed seed value.
diff --git a/scripts/builtin/smote.dml b/scripts/builtin/smote.dml
index 857120e..8e04e1d 100644
--- a/scripts/builtin/smote.dml
+++ b/scripts/builtin/smote.dml
@@ -46,41 +46,54 @@
print("the number of samples should be an integral multiple of 100. Setting s = 100")
s = 100
}
+
+ if(k < 1) {
+ print("k should not be less than 1. Setting k value to default k = 1.")
+ k = 1
+ }
+
# matrix to keep the index of KNN for each minority sample
- knn_index = matrix(0,k,0)
+ knn_index = matrix(0,k,nrow(X))
# find nearest neighbour
for(i in 1:nrow(X))
{
knn = nn(X, X[i, ], k)
- knn_index = cbind(knn_index, knn)
+ knn_index[, i] = knn
}
# number of synthetic samples from each minority class sample
- iter = (s/100)
+ iter = 0
+ iterLim = (s/100)
# matrix to store synthetic samples
- synthetic_samples = matrix(0, 0, ncol(X))
- while(iter > 0)
+ synthetic_samples = matrix(0, iterLim*ncol(knn_index), ncol(X))
+
+ # shuffle the nn indexes
+ if(k < iterLim)
+ rand_index = sample(k, iterLim, TRUE)
+ else
+ rand_index = sample(k, iterLim)
+
+ while(iter < iterLim)
{
- # generate a random number
- # TODO avoid duplicate random numbers
- rand_index = as.integer(as.scalar(Rand(rows=1, cols=1, min=1, max=k)))
# pick the random NN
- knn_sample = knn_index[rand_index,]
+ knn_sample = knn_index[as.scalar(rand_index[iter+1]),]
# generate sample
for(i in 1:ncol(knn_index))
{
index = as.scalar(knn_sample[1,i])
X_diff = X[index,] - X[i, ]
- gap = as.scalar(Rand(rows=1, cols=1, min=0, max=1))
+ gap = as.scalar(Rand(rows=1, cols=1, min=0, max=1, seed = 41))
X_sys = X[i, ] + (gap*X_diff)
- synthetic_samples = rbind(synthetic_samples, X_sys)
+ synthetic_samples[iter*ncol(knn_index)+i,] = X_sys;
}
- iter = iter - 1
+ iter = iter + 1
}
Y = synthetic_samples
+
if(verbose)
print(nrow(Y)+ " synthesized samples generated.")
+
}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java
index 675e741..4eb2fdf 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSmoteTest.java
@@ -48,8 +48,13 @@
}
@Test
+ public void testSmote0CP() {
+ runSmoteTest(100, 1, LopProperties.ExecType.CP);
+ }
+
+ @Test
public void testSmote1CP() {
- runSmoteTest(300, 3, LopProperties.ExecType.CP);
+ runSmoteTest(300, 10, LopProperties.ExecType.CP);
}
@Test