/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.geronimo.microprofile.impl.jwtauth.cdi;

import static java.util.Objects.requireNonNull;
import static java.util.Optional.empty;
import static java.util.Optional.of;
import static java.util.Optional.ofNullable;
import static java.util.function.Function.identity;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collector;

import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.context.Dependent;
import javax.enterprise.context.RequestScoped;
import javax.enterprise.event.Observes;
import javax.enterprise.inject.Any;
import javax.enterprise.inject.Default;
import javax.enterprise.inject.Instance;
import javax.enterprise.inject.Vetoed;
import javax.enterprise.inject.spi.AfterBeanDiscovery;
import javax.enterprise.inject.spi.AfterDeploymentValidation;
import javax.enterprise.inject.spi.BeforeBeanDiscovery;
import javax.enterprise.inject.spi.Extension;
import javax.enterprise.inject.spi.InjectionPoint;
import javax.enterprise.inject.spi.ProcessInjectionPoint;
import javax.enterprise.util.AnnotationLiteral;
import javax.enterprise.util.Nonbinding;
import javax.inject.Provider;
import javax.json.JsonArray;
import javax.json.JsonArrayBuilder;
import javax.json.JsonNumber;
import javax.json.JsonObject;
import javax.json.JsonString;
import javax.json.JsonValue;
import javax.json.spi.JsonProvider;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;

import org.apache.geronimo.microprofile.impl.jwtauth.config.GeronimoJwtAuthConfig;
import org.apache.geronimo.microprofile.impl.jwtauth.jwt.ContextualJsonWebToken;
import org.apache.geronimo.microprofile.impl.jwtauth.servlet.JwtRequest;
import org.eclipse.microprofile.jwt.Claim;
import org.eclipse.microprofile.jwt.ClaimValue;
import org.eclipse.microprofile.jwt.Claims;
import org.eclipse.microprofile.jwt.JsonWebToken;

public class GeronimoJwtAuthExtension implements Extension {
    private final ThreadLocal<JwtRequest> request = new ThreadLocal<>();

    private final Collection<Injection> injectionPoints = new HashSet<>(8);
    private final Collection<Throwable> errors = new ArrayList<>();
    private JsonProvider json;

    void setClaimMethodsBinding(@Observes final BeforeBeanDiscovery beforeBeanDiscovery) {
        beforeBeanDiscovery.configureQualifier(Claim.class)
                .methods().forEach(m -> m.remove(it -> it.annotationType() == Nonbinding.class));
        json = JsonProvider.provider();
    }

    void captureInjections(@Observes final ProcessInjectionPoint<?, ?> processInjectionPoint) {
        final InjectionPoint injectionPoint = processInjectionPoint.getInjectionPoint();
        ofNullable(injectionPoint.getAnnotated().getAnnotation(Claim.class))
                .flatMap(claim -> createInjection(claim, injectionPoint.getType()))
                .ifPresent(injectionPoints::add);
    }

    void addClaimBeans(@Observes final AfterBeanDiscovery afterBeanDiscovery) {
        // it is another instance than th eone used in our initializer but it should be backed by the same impl
        afterBeanDiscovery.addBean()
                .id(GeronimoJwtAuthExtension.class.getName() + "#" + GeronimoJwtAuthConfig.class.getName())
                .beanClass(GeronimoJwtAuthConfig.class)
                .types(GeronimoJwtAuthConfig.class, Object.class)
                .qualifiers(Default.Literal.INSTANCE, Any.Literal.INSTANCE)
                .scope(ApplicationScoped.class)
                .createWith(ctx -> GeronimoJwtAuthConfig.create());

        afterBeanDiscovery.addBean()
                .id(GeronimoJwtAuthExtension.class.getName() + "#" + JsonWebToken.class.getName())
                .beanClass(JsonWebToken.class)
                .types(JsonWebToken.class, Object.class)
                .qualifiers(Default.Literal.INSTANCE, Any.Literal.INSTANCE)
                .scope(ApplicationScoped.class)
                .createWith(ctx -> new ContextualJsonWebToken(() -> {
                    final JwtRequest request = this.request.get();
                    if (request == null) {
                        throw new IllegalStateException("No JWT in this request");
                    }
                    return request.getToken();
                }));

        injectionPoints.forEach(injection ->
                afterBeanDiscovery.addBean()
                        .id(GeronimoJwtAuthExtension.class.getName() + "#" + injection.getId())
                        .beanClass(injection.findClass())
                        .qualifiers(injection.literal(), Any.Literal.INSTANCE)
                        .scope(injection.findScope())
                        .types(injection.type, Object.class)
                        .createWith(ctx -> injection.createInstance(request.get())));

        injectionPoints.clear();
    }

    void afterDeployment(@Observes final AfterDeploymentValidation afterDeploymentValidation) {
        errors.forEach(afterDeploymentValidation::addDeploymentProblem);
    }

    private Optional<Injection> createInjection(final Claim claim, final Type type) {
        if (ParameterizedType.class.isInstance(type)) {
            final ParameterizedType pt = ParameterizedType.class.cast(type);
            if (pt.getActualTypeArguments().length == 1) {
                final Type raw = pt.getRawType();
                final Type arg = pt.getActualTypeArguments()[0];

                if (raw == Provider.class || raw == Instance.class) {
                    return createInjection(claim, arg);
                }
                if (raw == Optional.class) {
                    return createInjection(claim, arg)
                            .map(it -> new Injection(claim.value(), claim.standard(), type) {
                                @Override
                                Object createInstance(final JwtRequest jwtRequest) {
                                    return ofNullable(it.createInstance(jwtRequest));
                                }
                            });
                }
                if (raw == ClaimValue.class) {
                    final String name = getClaimName(claim);
                    return createInjection(claim, arg)
                            .map(it -> new Injection(claim.value(), claim.standard(), type) {
                                @Override
                                Object createInstance(final JwtRequest jwtRequest) {
                                    return new ClaimValue<Object>() {
                                        @Override
                                        public String getName() {
                                            return name;
                                        }

                                        @Override
                                        public Object getValue() {
                                            return it.createInstance(jwtRequest);
                                        }
                                    };
                                }
                            });
                }
                if (Class.class.isInstance(raw) && Collection.class.isAssignableFrom(Class.class.cast(raw))) {
                    return of(new Injection(claim.value(), claim.standard(), type));
                }
            }
        } else if (Class.class.isInstance(type)) {
            final Class<?> clazz = Class.class.cast(type);
            if (JsonValue.class.isAssignableFrom(clazz)) {
                if (JsonString.class.isAssignableFrom(clazz)) {
                    return of(new Injection(claim.value(), claim.standard(), clazz) {
                        @Override
                        Object createInstance(final JwtRequest jwtRequest) {
                            final Object instance = super.createInstance(jwtRequest);
                            if (JsonString.class.isInstance(instance)) {
                                return instance;
                            }
                            return json.createValue(String.class.cast(instance));
                        }
                    });
                }
                if (JsonNumber.class.isAssignableFrom(clazz)) {
                    return of(new Injection(claim.value(), claim.standard(), clazz) {
                        @Override
                        Object createInstance(final JwtRequest jwtRequest) {
                            final Object instance = super.createInstance(jwtRequest);
                            if (JsonNumber.class.isInstance(instance)) {
                                return instance;
                            }
                            return json.createValue(Number.class.cast(instance).doubleValue());
                        }
                    });
                }
                if (JsonObject.class.isAssignableFrom(clazz)) {
                    return of(new Injection(claim.value(), claim.standard(), clazz));
                }
                if (JsonArray.class.isAssignableFrom(clazz)) {
                    return of(new Injection(claim.value(), claim.standard(), clazz) {
                        @Override
                        Object createInstance(final JwtRequest jwtRequest) {
                            final Object instance = super.createInstance(jwtRequest);
                            if (instance == null) {
                                return null;
                            }
                            if (JsonArray.class.isInstance(instance)) {
                                return instance;
                            }
                            if (Set.class.isInstance(instance)) {
                                return ((Set<String>) instance).stream()
                                        .collect(Collector.of(
                                                json::createArrayBuilder,
                                                JsonArrayBuilder::add,
                                                JsonArrayBuilder::addAll,
                                                JsonArrayBuilder::build));
                            }
                            throw new IllegalArgumentException("Unsupported value: " + instance);
                        }
                    });
                }
            } else {
                final Class<?> objectType = wrapPrimitives(clazz);
                if (CharSequence.class.isAssignableFrom(clazz) || Double.class.isAssignableFrom(objectType) ||
                        Long.class.isAssignableFrom(objectType) || Integer.class.isAssignableFrom(objectType)) {
                    return of(new Injection(claim.value(), claim.standard(), objectType));
                }
            }
        }
        errors.add(new IllegalArgumentException(type + " not supported by JWT-Auth implementation"));
        return empty();
    }

    private Class<?> wrapPrimitives(final Class<?> type) {
        if (long.class == type) {
            return Long.class;
        }
        if (int.class == type) {
            return Integer.class;
        }
        if (double.class == type) {
            return Double.class;
        }
        return type;
    }

    private static String getClaimName(final Claim claim) {
        return getClaimName(claim.value(), claim.standard());
    }

    private static String getClaimName(final String name, final Claims val) {
        return of(name).filter(s -> !s.isEmpty()).orElse(val.name());
    }

    public void execute(final HttpServletRequest req, final ServletRunnable task) {
        try {
            final JwtRequest jwtRequest = requireNonNull(JwtRequest.class.isInstance(req) ?
                            JwtRequest.class.cast(req) : JwtRequest.class.cast(req.getAttribute(JwtRequest.class.getName())),
                    "No JwtRequest");
            execute(jwtRequest, task);
        } catch (final IOException | ServletException e) {
            throw new IllegalStateException(e);
        }
    }

    public void execute(final JwtRequest req, final ServletRunnable task) throws ServletException, IOException {
        request.set(req); // we want to track it ourself to support propagation properly when needed
        try {
            task.run();
        } finally {
            request.remove();
        }
    }

    @FunctionalInterface
    public interface ServletRunnable {
        void run() throws ServletException, IOException;
    }

    private static class Injection {
        private final String name;
        private final Claims claims;
        private final Type type;
        private final int hash;
        private final Function<Object, Object> transformer;
        private final String runtimeName;

        private Injection(final String name, final Claims claims, final Type type) {
            this.name = name;
            this.claims = claims;
            this.type = type;

            Function<Object, Object> transformer;
            try {
                Claims.valueOf(getClaimName(name, claims));
                transformer = identity();
            } catch (final IllegalArgumentException iae) {
                if (type == String.class) {
                    transformer = val -> val == null ? null : JsonString.class.cast(val).getString();
                } else if (type == Long.class) {
                    transformer = val -> val == null ? null : JsonNumber.class.cast(val).longValue();
                } else {
                    transformer = identity();
                }
            }
            this.transformer = transformer;
            this.runtimeName = getClaimName(name, claims);

            {
                int result = name.hashCode();
                result = 31 * result + claims.hashCode();
                hash = 31 * result + type.hashCode();
            }
        }

        private String getId() {
            return name + "/" + claims + "/" + type;
        }

        private Class<?> findClass() {
            if (Class.class.isInstance(type)) {
                return Class.class.cast(type);
            }
            if (ParameterizedType.class.isInstance(type)) {
                ParameterizedType current = ParameterizedType.class.cast(type);
                while (!Class.class.isInstance(current.getRawType())) {
                    current = ParameterizedType.class.cast(current.getRawType());
                }
                return Class.class.cast(current.getRawType());
            }
            throw new IllegalArgumentException("Can't find a class from " + type);
        }

        private Class<? extends Annotation> findScope() {
            if (ClaimValue.class == findClass()) {
                return RequestScoped.class;
            }
            return Dependent.class;
        }

        private Annotation literal() {
            return new ClaimLiteral(name, claims);
        }

        Object createInstance(final JwtRequest jwtRequest) {
            return transformer.apply(jwtRequest.getToken().getClaim(runtimeName));
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            final Injection injection = Injection.class.cast(o);
            return runtimeName.equals(injection.runtimeName) && type.equals(injection.type);
        }

        @Override
        public int hashCode() {
            return hash;
        }

        @Override
        public String toString() {
            return "Injection{claim='" + runtimeName + "', type=" + type + '}';
        }
    }

    @Vetoed
    private static class ClaimLiteral extends AnnotationLiteral<Claim> implements Claim {
        private final String name;
        private final Claims claims;

        private ClaimLiteral(final String name, final Claims claims) {
            this.name = name;
            this.claims = claims;
        }

        @Override
        public String value() {
            return name;
        }

        @Override
        public Claims standard() {
            return claims;
        }
    }
}
