/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.direct.portable.artifact;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactChunk;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactMetadata;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetArtifactRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestResponse;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest;
import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc;
import org.apache.beam.runners.core.construction.ArtifactServiceStager;
import org.apache.beam.runners.core.construction.ArtifactServiceStager.StagedFile;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.InProcessServerFactory;
import org.apache.beam.runners.fnexecution.ServerFactory;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests for {@link LocalFileSystemArtifactRetrievalService}.
 *
 * <p>Each test stages artifacts through a {@link LocalFileSystemArtifactStagerService} rooted at a
 * temporary directory, then reads them back through a retrieval service created over the same
 * directory.
 */
@RunWith(JUnit4.class)
public class LocalFileSystemArtifactRetrievalServiceTest {
  /** Longest time, in seconds, a test waits for an asynchronous response before failing. */
  private static final long RESPONSE_TIMEOUT_SECONDS = 60;

  @Rule public TemporaryFolder tmp = new TemporaryFolder();

  private File root;
  private ServerFactory serverFactory = InProcessServerFactory.create();

  private GrpcFnServer<LocalFileSystemArtifactStagerService> stagerServer;

  private GrpcFnServer<LocalFileSystemArtifactRetrievalService> retrievalServer;
  private ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceStub retrievalStub;

  /** Creates the root directory and starts the stager service over it. */
  @Before
  public void setup() throws Exception {
    root = tmp.newFolder();
    stagerServer =
        GrpcFnServer.allocatePortAndCreateFor(
            LocalFileSystemArtifactStagerService.forRootDirectory(root), serverFactory);
  }

  /**
   * Closes whichever servers were started.
   *
   * <p>Either field may still be {@code null} when setup or staging failed part-way, so both are
   * guarded; the retrieval server is closed even if closing the stager server throws.
   */
  @After
  public void teardown() throws Exception {
    try {
      if (stagerServer != null) {
        stagerServer.close();
      }
    } finally {
      if (retrievalServer != null) {
        retrievalServer.close();
      }
    }
  }

  /** The manifest returned by the retrieval service lists every staged artifact by name. */
  @Test
  public void retrieveManifest() throws Exception {
    Map<String, byte[]> artifacts = new HashMap<>();
    artifacts.put("foo", "bar, baz, quux".getBytes(UTF_8));
    artifacts.put("spam", new byte[] {127, -22, 5});
    stageAndCreateRetrievalService(artifacts);

    final AtomicReference<Manifest> returned = new AtomicReference<>();
    final AtomicReference<Throwable> error = new AtomicReference<>();
    final CountDownLatch completed = new CountDownLatch(1);
    retrievalStub.getManifest(
        GetManifestRequest.getDefaultInstance(),
        new StreamObserver<GetManifestResponse>() {
          @Override
          public void onNext(GetManifestResponse value) {
            returned.set(value.getManifest());
          }

          @Override
          public void onError(Throwable t) {
            // Keep the cause so that a service failure is reported as such by the test.
            error.set(t);
            completed.countDown();
          }

          @Override
          public void onCompleted() {
            completed.countDown();
          }
        });

    awaitCompletion(completed);
    assertNoError("getManifest", error.get());
    assertThat(returned.get(), not(nullValue()));

    List<String> manifestArtifacts = new ArrayList<>();
    for (ArtifactMetadata artifactMetadata : returned.get().getArtifactList()) {
      manifestArtifacts.add(artifactMetadata.getName());
    }
    assertThat(manifestArtifacts, containsInAnyOrder("foo", "spam"));
  }

  /** Each staged artifact can be retrieved by name with exactly the bytes that were staged. */
  @Test
  public void retrieveArtifact() throws Exception {
    Map<String, byte[]> artifacts = new HashMap<>();
    byte[] fooContents = "bar, baz, quux".getBytes(UTF_8);
    artifacts.put("foo", fooContents);
    byte[] spamContents = {127, -22, 5};
    artifacts.put("spam", spamContents);
    stageAndCreateRetrievalService(artifacts);

    // One count per outstanding getArtifact call.
    final CountDownLatch completed = new CountDownLatch(2);
    ByteArrayOutputStream returnedFooBytes = new ByteArrayOutputStream();
    MultimapChunkAppender fooAppender = new MultimapChunkAppender(returnedFooBytes, completed);
    retrievalStub.getArtifact(
        GetArtifactRequest.newBuilder().setName("foo").build(), fooAppender);
    ByteArrayOutputStream returnedSpamBytes = new ByteArrayOutputStream();
    MultimapChunkAppender spamAppender = new MultimapChunkAppender(returnedSpamBytes, completed);
    retrievalStub.getArtifact(
        GetArtifactRequest.newBuilder().setName("spam").build(), spamAppender);

    awaitCompletion(completed);
    assertNoError("getArtifact(foo)", fooAppender.getError());
    assertNoError("getArtifact(spam)", spamAppender.getError());
    assertArrayEquals(fooContents, returnedFooBytes.toByteArray());
    assertArrayEquals(spamContents, returnedSpamBytes.toByteArray());
  }

  /** Requesting an artifact that was never staged produces an error naming that artifact. */
  @Test
  public void retrieveArtifactNotPresent() throws Exception {
    stageAndCreateRetrievalService(
        Collections.singletonMap("foo", "bar, baz, quux".getBytes(UTF_8)));

    final CountDownLatch completed = new CountDownLatch(1);
    final AtomicReference<Throwable> thrown = new AtomicReference<>();
    final AtomicReference<ArtifactChunk> unexpectedChunk = new AtomicReference<>();
    retrievalStub.getArtifact(
        GetArtifactRequest.newBuilder().setName("spam").build(),
        new StreamObserver<ArtifactChunk>() {
          @Override
          public void onNext(ArtifactChunk value) {
            // This callback may run on a thread other than the test thread, so the chunk is
            // recorded here and asserted on after the call finishes instead of failing in place.
            unexpectedChunk.set(value);
          }

          @Override
          public void onError(Throwable t) {
            thrown.set(t);
            completed.countDown();
          }

          @Override
          public void onCompleted() {
            completed.countDown();
          }
        });

    awaitCompletion(completed);
    assertThat(
        "Should never receive an "
            + ArtifactChunk.class.getSimpleName()
            + " for a nonexistent artifact",
        unexpectedChunk.get(),
        nullValue());
    assertThat(thrown.get(), not(nullValue()));
    assertThat(thrown.get().getMessage(), containsString("No such artifact"));
    assertThat(thrown.get().getMessage(), containsString("spam"));
  }

  /**
   * Stages the given artifacts and then starts a retrieval service over the same root directory.
   *
   * <p>Each entry is written to a new temporary file named after its key and staged under that
   * file name. On return, {@link #retrievalServer} and {@link #retrievalStub} are initialized.
   *
   * @param artifacts artifact name to artifact contents
   */
  private void stageAndCreateRetrievalService(Map<String, byte[]> artifacts) throws Exception {
    List<StagedFile> artifactFiles = new ArrayList<>();
    for (Map.Entry<String, byte[]> artifact : artifacts.entrySet()) {
      File artifactFile = tmp.newFile(artifact.getKey());
      Files.write(artifactFile.toPath(), artifact.getValue());
      artifactFiles.add(StagedFile.of(artifactFile, artifactFile.getName()));
    }
    String stagingSessionToken = "token";

    ArtifactServiceStager stager =
        ArtifactServiceStager.overChannel(
            InProcessChannelBuilder.forName(stagerServer.getApiServiceDescriptor().getUrl())
                .build());
    stager.stage(stagingSessionToken, artifactFiles);

    retrievalServer =
        GrpcFnServer.allocatePortAndCreateFor(
            LocalFileSystemArtifactRetrievalService.forRootDirectory(root), serverFactory);
    retrievalStub =
        ArtifactRetrievalServiceGrpc.newStub(
            InProcessChannelBuilder.forName(retrievalServer.getApiServiceDescriptor().getUrl())
                .build());
  }

  /**
   * Waits for the latch to reach zero, failing the test rather than blocking forever if it does
   * not do so within {@link #RESPONSE_TIMEOUT_SECONDS}.
   */
  private static void awaitCompletion(CountDownLatch completed) throws InterruptedException {
    if (!completed.await(RESPONSE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
      fail(
          "Timed out after "
              + RESPONSE_TIMEOUT_SECONDS
              + " seconds waiting for the retrieval service to respond");
    }
  }

  /**
   * Fails the test, preserving {@code error} as the cause, if the named call reported an error.
   *
   * @param call short description of the call, used in the failure message
   * @param error the error delivered to the observer, or {@code null} if there was none
   */
  private static void assertNoError(String call, Throwable error) {
    if (error != null) {
      throw new AssertionError(call + " failed unexpectedly", error);
    }
  }

  /**
   * A {@link StreamObserver} that appends the data of every received chunk to a stream and counts
   * down a latch once the call terminates, remembering the error if it terminated with one.
   */
  private static class MultimapChunkAppender implements StreamObserver<ArtifactChunk> {
    private final ByteArrayOutputStream target;
    private final CountDownLatch completed;
    // Written by the callback before the latch is counted down, read by the test afterwards.
    private volatile Throwable error;

    private MultimapChunkAppender(ByteArrayOutputStream target, CountDownLatch completed) {
      this.target = target;
      this.completed = completed;
    }

    /** Returns the error the call terminated with, or {@code null} if none was reported. */
    private Throwable getError() {
      return error;
    }

    @Override
    public void onNext(ArtifactChunk value) {
      try {
        target.write(value.getData().toByteArray());
      } catch (IOException e) {
        // This should never happen
        throw new AssertionError(e);
      }
    }

    @Override
    public void onError(Throwable t) {
      error = t;
      completed.countDown();
    }

    @Override
    public void onCompleted() {
      completed.countDown();
    }
  }
}
