// blob: d546c307c0d136e8b558d35ede21019e10e2519d [file] [log] [blame]
/****************************************************************
* Licensed to the Apache Software Foundation (ASF) under one *
* or more contributor license agreements. See the NOTICE file *
* distributed with this work for additional information *
* regarding copyright ownership. The ASF licenses this file *
* to you under the Apache License, Version 2.0 (the *
* "License"); you may not use this file except in compliance *
* with the License. You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, *
* software distributed under the License is distributed on an *
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY *
* KIND, either express or implied. See the License for the *
* specific language governing permissions and limitations *
* under the License. *
****************************************************************/
package org.apache.james.rspamd.task;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.james.mailbox.MailboxManager;
import org.apache.james.mailbox.MessageIdManager;
import org.apache.james.mailbox.model.MessageResult;
import org.apache.james.mailbox.store.MailboxSessionMapperFactory;
import org.apache.james.rspamd.client.RSpamDHttpClient;
import org.apache.james.task.Task;
import org.apache.james.task.TaskExecutionDetails;
import org.apache.james.task.TaskType;
import org.apache.james.user.api.UsersRepository;
import org.apache.james.user.api.UsersRepositoryException;
import org.apache.james.util.ReactorUtils;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.fge.lambdas.Throwing;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import reactor.core.publisher.Mono;
public class FeedSpamToRSpamDTask implements Task {
/** Mailbox name that {@link #run()} passes to the message lookup when collecting messages to report. */
public static final String SPAM_MAILBOX_NAME = "Spam";
/** Type identifier of this task; this is the value returned by {@link #type()}. */
public static final TaskType TASK_TYPE = TaskType.of("FeedSpamToRSpamDTask");
/**
 * Immutable settings for one execution of the task: an optional look-back period in seconds,
 * the throttling rate in messages per second, and the sampling probability.
 *
 * <p>The {@code @JsonProperty} names on the constructor parameters are the property names
 * under which the three values are bound.
 */
public static class RunningOptions {
    /** Default period: empty, i.e. no look-back period is configured. */
    public static final Optional<Long> DEFAULT_PERIOD = Optional.empty();
    /** Default number of messages handled per second. */
    public static final int DEFAULT_MESSAGES_PER_SECOND = 10;
    /** Default sampling probability. */
    public static final double DEFAULT_SAMPLING_PROBABILITY = 1;
    /** Options assembled from the three defaults; declared after the values it is built from. */
    public static final RunningOptions DEFAULT = new RunningOptions(
        DEFAULT_PERIOD,
        DEFAULT_MESSAGES_PER_SECOND,
        DEFAULT_SAMPLING_PROBABILITY);

    private final Optional<Long> periodInSecond;
    private final int messagesPerSecond;
    private final double samplingProbability;

    /**
     * Creates the options; the arguments are stored as given, without validation.
     *
     * @param periodInSecond optional look-back period, in seconds
     * @param messagesPerSecond number of messages handled per second
     * @param samplingProbability sampling probability forwarded to the message lookup
     */
    public RunningOptions(
            @JsonProperty("periodInSecond") Optional<Long> periodInSecond,
            @JsonProperty("messagesPerSecond") int messagesPerSecond,
            @JsonProperty("samplingProbability") double samplingProbability) {
        this.samplingProbability = samplingProbability;
        this.messagesPerSecond = messagesPerSecond;
        this.periodInSecond = periodInSecond;
    }

    /** @return the optional look-back period, in seconds */
    public Optional<Long> getPeriodInSecond() {
        return this.periodInSecond;
    }

    /** @return the number of messages handled per second */
    public int getMessagesPerSecond() {
        return this.messagesPerSecond;
    }

    /** @return the sampling probability */
    public double getSamplingProbability() {
        return this.samplingProbability;
    }
}
/**
 * Immutable execution report: the counters and settings read from a {@link Context.Snapshot},
 * plus the instant at which the report was created.
 */
public static class AdditionalInformation implements TaskExecutionDetails.AdditionalInformation {
    private final Instant timestamp;
    private final long spamMessageCount;
    private final long reportedSpamMessageCount;
    private final long errorCount;
    private final int messagesPerSecond;
    private final Optional<Long> period;
    private final double samplingProbability;

    /**
     * Stores every argument as given.
     *
     * @param timestamp instant attached to this report
     * @param spamMessageCount value of the spam message counter
     * @param reportedSpamMessageCount value of the reported spam message counter
     * @param errorCount value of the error counter
     * @param messagesPerSecond configured throttling rate
     * @param period configured optional period
     * @param samplingProbability configured sampling probability
     */
    public AdditionalInformation(Instant timestamp, long spamMessageCount, long reportedSpamMessageCount, long errorCount, int messagesPerSecond, Optional<Long> period, double samplingProbability) {
        this.timestamp = timestamp;
        this.spamMessageCount = spamMessageCount;
        this.reportedSpamMessageCount = reportedSpamMessageCount;
        this.errorCount = errorCount;
        this.messagesPerSecond = messagesPerSecond;
        this.period = period;
        this.samplingProbability = samplingProbability;
    }

    /**
     * Builds a report from the current state of {@code context}, stamped with the system UTC clock.
     * The snapshot is taken before the clock is read.
     */
    private static AdditionalInformation from(Context context) {
        Context.Snapshot current = context.snapshot();
        Instant now = Clock.systemUTC().instant();
        return new AdditionalInformation(
            now,
            current.getSpamMessageCount(),
            current.getReportedSpamMessageCount(),
            current.getErrorCount(),
            current.getMessagesPerSecond(),
            current.getPeriod(),
            current.getSamplingProbability());
    }

    /** @return the instant attached to this report */
    @Override
    public Instant timestamp() {
        return this.timestamp;
    }

    /** @return the spam message counter value */
    public long getSpamMessageCount() {
        return this.spamMessageCount;
    }

    /** @return the reported spam message counter value */
    public long getReportedSpamMessageCount() {
        return this.reportedSpamMessageCount;
    }

    /** @return the error counter value */
    public long getErrorCount() {
        return this.errorCount;
    }

    /** @return the configured throttling rate */
    public int getMessagesPerSecond() {
        return this.messagesPerSecond;
    }

    /** @return the configured optional period */
    public Optional<Long> getPeriod() {
        return this.period;
    }

    /** @return the configured sampling probability */
    public double getSamplingProbability() {
        return this.samplingProbability;
    }
}
public static class Context {
public static class Snapshot {
public static Builder builder() {
return new Builder();
}
static class Builder {
private Optional<Long> spamMessageCount;
private Optional<Long> reportedSpamMessageCount;
private Optional<Long> errorCount;
private Optional<Integer> messagesPerSecond;
private Optional<Long> period;
private Optional<Double> samplingProbability;
Builder() {
spamMessageCount = Optional.empty();
reportedSpamMessageCount = Optional.empty();
errorCount = Optional.empty();
messagesPerSecond = Optional.empty();
period = Optional.empty();
samplingProbability = Optional.empty();
}
public Snapshot build() {
return new Snapshot(
spamMessageCount.orElse(0L),
reportedSpamMessageCount.orElse(0L),
errorCount.orElse(0L),
messagesPerSecond.orElse(0),
period,
samplingProbability.orElse(1D));
}
public Builder spamMessageCount(long spamMessageCount) {
this.spamMessageCount = Optional.of(spamMessageCount);
return this;
}
public Builder reportedSpamMessageCount(long reportedSpamMessageCount) {
this.reportedSpamMessageCount = Optional.of(reportedSpamMessageCount);
return this;
}
public Builder errorCount(long errorCount) {
this.errorCount = Optional.of(errorCount);
return this;
}
public Builder messagesPerSecond(int messagesPerSecond) {
this.messagesPerSecond = Optional.of(messagesPerSecond);
return this;
}
public Builder period(Optional<Long> period) {
this.period = period;
return this;
}
public Builder samplingProbability(double samplingProbability) {
this.samplingProbability = Optional.of(samplingProbability);
return this;
}
}
private final long spamMessageCount;
private final long reportedSpamMessageCount;
private final long errorCount;
private final int messagesPerSecond;
private final Optional<Long> period;
private final double samplingProbability;
public Snapshot(long spamMessageCount, long reportedSpamMessageCount, long errorCount, int messagesPerSecond, Optional<Long> period,
double samplingProbability) {
this.spamMessageCount = spamMessageCount;
this.reportedSpamMessageCount = reportedSpamMessageCount;
this.errorCount = errorCount;
this.messagesPerSecond = messagesPerSecond;
this.period = period;
this.samplingProbability = samplingProbability;
}
public long getSpamMessageCount() {
return spamMessageCount;
}
public long getReportedSpamMessageCount() {
return reportedSpamMessageCount;
}
public long getErrorCount() {
return errorCount;
}
public int getMessagesPerSecond() {
return messagesPerSecond;
}
public Optional<Long> getPeriod() {
return period;
}
public double getSamplingProbability() {
return samplingProbability;
}
@Override
public final boolean equals(Object o) {
if (o instanceof Snapshot) {
Snapshot snapshot = (Snapshot) o;
return Objects.equals(this.spamMessageCount, snapshot.spamMessageCount)
&& Objects.equals(this.reportedSpamMessageCount, snapshot.reportedSpamMessageCount)
&& Objects.equals(this.errorCount, snapshot.errorCount)
&& Objects.equals(this.messagesPerSecond, snapshot.messagesPerSecond)
&& Objects.equals(this.samplingProbability, snapshot.samplingProbability)
&& Objects.equals(this.period, snapshot.period);
}
return false;
}
@Override
public final int hashCode() {
return Objects.hash(spamMessageCount, reportedSpamMessageCount, errorCount, messagesPerSecond, period, samplingProbability);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("spamMessageCount", spamMessageCount)
.add("reportedSpamMessageCount", reportedSpamMessageCount)
.add("errorCount", errorCount)
.add("messagesPerSecond", messagesPerSecond)
.add("period", period)
.add("samplingProbability", samplingProbability)
.toString();
}
}
private final AtomicLong spamMessageCount;
private final AtomicLong reportedSpamMessageCount;
private final AtomicLong errorCount;
private final Integer messagesPerSecond;
private final Optional<Long> period;
private final Double samplingProbability;
public Context(RunningOptions runningOptions) {
this.spamMessageCount = new AtomicLong();
this.reportedSpamMessageCount = new AtomicLong();
this.errorCount = new AtomicLong();
this.messagesPerSecond = runningOptions.messagesPerSecond;
this.period = runningOptions.periodInSecond;
this.samplingProbability = runningOptions.samplingProbability;
}
public void incrementSpamMessageCount() {
spamMessageCount.incrementAndGet();
}
public void incrementReportedSpamMessageCount(int count) {
reportedSpamMessageCount.addAndGet(count);
}
public void incrementErrorCount() {
errorCount.incrementAndGet();
}
public Snapshot snapshot() {
return Snapshot.builder()
.spamMessageCount(spamMessageCount.get())
.reportedSpamMessageCount(reportedSpamMessageCount.get())
.errorCount(errorCount.get())
.messagesPerSecond(messagesPerSecond)
.period(period)
.samplingProbability(samplingProbability)
.build();
}
}
private final GetMailboxMessagesService messagesService;
private final RSpamDHttpClient rSpamDHttpClient;
private final RunningOptions runningOptions;
private final Context context;
private final Clock clock;

/**
 * Creates the task.
 *
 * <p>The mailbox manager, users repository, mapper factory and message id manager are only
 * used to build the internal {@code GetMailboxMessagesService}; a fresh {@link Context} is
 * created from {@code runningOptions}.
 *
 * @param rSpamDHttpClient client used by {@link #run()} to report each message
 * @param runningOptions period, throttling rate and sampling probability for this execution
 * @param clock clock from which {@link #run()} derives the cut-off date when a period is set
 */
public FeedSpamToRSpamDTask(MailboxManager mailboxManager, UsersRepository usersRepository, MessageIdManager messageIdManager, MailboxSessionMapperFactory mapperFactory,
                            RSpamDHttpClient rSpamDHttpClient, RunningOptions runningOptions, Clock clock) {
    this.rSpamDHttpClient = rSpamDHttpClient;
    this.clock = clock;
    this.runningOptions = runningOptions;
    // The service is built before the context, as in the original statement order.
    this.messagesService = new GetMailboxMessagesService(mailboxManager, usersRepository, mapperFactory, messageIdManager);
    this.context = new Context(runningOptions);
}
/**
 * Fetches messages from the {@link #SPAM_MAILBOX_NAME} mailbox through the messages service
 * and reports each one through the RSpamD client, at a rate of
 * {@code runningOptions.messagesPerSecond} messages per second.
 *
 * <p>When a period is configured, the cut-off date handed to the lookup is "now" (from the
 * injected clock) minus that many seconds; otherwise no date is passed. The call blocks until
 * the whole pipeline has completed.
 *
 * @return the per-message results folded with {@code Task::combine};
 *         {@code COMPLETED} when no message was processed;
 *         {@code PARTIAL} when the lookup throws {@link UsersRepositoryException}
 */
@Override
public Result run() {
    Optional<Date> afterDate = runningOptions.periodInSecond
        .map(seconds -> clock.instant().minusSeconds(seconds))
        .map(Date::from);

    try {
        return messagesService.getMailboxMessagesOfAllUser(SPAM_MAILBOX_NAME, afterDate, runningOptions.getSamplingProbability(), context)
            .transform(ReactorUtils.<MessageResult, Task.Result>throttle()
                .elements(runningOptions.messagesPerSecond)
                .per(Duration.ofSeconds(1))
                .forOperation(this::reportToRSpamD))
            .reduce(Task::combine)
            .switchIfEmpty(Mono.just(Result.COMPLETED))
            .block();
    } catch (UsersRepositoryException e) {
        LOGGER.error("Error while accessing users from repository", e);
        return Task.Result.PARTIAL;
    }
}

/**
 * Reports one message as spam and maps the outcome to a task result: on success the
 * reported counter is incremented and {@code COMPLETED} is emitted; on any error the error
 * is logged, the error counter is incremented and {@code PARTIAL} is emitted.
 */
private Mono<Result> reportToRSpamD(MessageResult messageResult) {
    // NOTE(review): fromSupplier only emits whatever reportAsSpam returns and then(...) discards
    // that element. If reportAsSpam returns a lazy publisher rather than performing the call
    // itself, that publisher is never subscribed here - confirm against RSpamDHttpClient.
    return Mono.fromSupplier(Throwing.supplier(() -> rSpamDHttpClient.reportAsSpam(messageResult.getFullContent().getInputStream())))
        .then(Mono.fromCallable(() -> {
            context.incrementReportedSpamMessageCount(1);
            return Result.COMPLETED;
        }))
        .onErrorResume(error -> {
            LOGGER.error("Error when report spam message to RSpamD", error);
            context.incrementErrorCount();
            return Mono.just(Result.PARTIAL);
        });
}
/** @return {@link #TASK_TYPE}, the type identifier of this task */
@Override
public TaskType type() {
    return FeedSpamToRSpamDTask.TASK_TYPE;
}
/**
 * @return a report built from the current state of this task's context; never empty
 */
@Override
public Optional<TaskExecutionDetails.AdditionalInformation> details() {
    AdditionalInformation currentDetails = AdditionalInformation.from(context);
    return Optional.of(currentDetails);
}
/** @return an immutable copy of the current counters and settings held by this task's context */
@VisibleForTesting
public Context.Snapshot snapshot() {
    return this.context.snapshot();
}
}