/**
 * Copyright 2005-2013 Restlet S.A.S.
 * 
 * The contents of this file are subject to the terms of one of the following
 * open source licenses: Apache 2.0 or LGPL 3.0 or LGPL 2.1 or CDDL 1.0 or EPL
 * 1.0 (the "Licenses"). You can select the license that you prefer but you may
 * not use this file except in compliance with one of these Licenses.
 * 
 * You can obtain a copy of the Apache 2.0 license at
 * http://www.opensource.org/licenses/apache-2.0
 * 
 * You can obtain a copy of the LGPL 3.0 license at
 * http://www.opensource.org/licenses/lgpl-3.0
 * 
 * You can obtain a copy of the LGPL 2.1 license at
 * http://www.opensource.org/licenses/lgpl-2.1
 * 
 * You can obtain a copy of the CDDL 1.0 license at
 * http://www.opensource.org/licenses/cddl1
 * 
 * You can obtain a copy of the EPL 1.0 license at
 * http://www.opensource.org/licenses/eclipse-1.0
 * 
 * See the Licenses for the specific language governing permissions and
 * limitations under the Licenses.
 * 
 * Alternatively, you can obtain a royalty free commercial license with less
 * limitations, transferable or non-transferable, directly at
 * http://www.restlet.com/products/restlet-framework
 * 
 * Restlet is a registered trademark of Restlet S.A.S.
 */

package org.restlet.ext.openid.internal;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;
import java.util.logging.Logger;

import org.openid4java.message.AuthSuccess;
import org.openid4java.message.DirectError;
import org.openid4java.message.Message;
import org.openid4java.message.MessageException;
import org.openid4java.message.MessageExtension;
import org.openid4java.message.ParameterList;
import org.openid4java.message.ax.AxMessage;
import org.openid4java.message.ax.FetchRequest;
import org.openid4java.message.ax.FetchResponse;
import org.openid4java.server.ServerManager;
import org.restlet.Context;
import org.restlet.Request;
import org.restlet.Response;
import org.restlet.engine.Engine;
import org.restlet.ext.openid.AttributeExchange;
import org.restlet.ext.openid.internal.ProviderResult.OPR;

/**
 * Describes the OpenID provider, also known as an OP.
 * 
 * @author Martin Svensson
 */
public class Provider {

    public enum OpenIdMode {
        associate, check_authentication, checkid_immediate, checkid_setup, errorMode;
    }

    public static final String OPENID_MODE = "openid.mode";

    public static final String OPENID_RETURNTO = "openid.return_to";

    public static final String OPENID_REALM = "openid.realm";

    private final Map<String, UserSession> sessions = new HashMap<String, UserSession>();

    public Message fetchAttributes(ParameterList pl) throws Exception {
        if (pl == null)
            return null;
        Message m = Message.createMessage(pl);
        if (m.hasExtension(AxMessage.OPENID_NS_AX)) {
            return m;
        }
        return null;
    }

    public Message fetchAttributes(UserSession us) throws Exception {
        return fetchAttributes(us.getParameterList());
    }

    public Logger getLogger() {
        Logger result = null;

        Context context = Context.getCurrent();

        if (context != null) {
            result = context.getLogger();
        }

        if (result == null) {
            result = Engine.getLogger(this, "org.restlet.ext.openid.OP");
        }

        return result;
    }

    public Set<AttributeExchange> getOptionalAttributes(UserSession us)
            throws Exception {
        return getAttributes(us.getParameterList(), false);
    }

    public Set<AttributeExchange> getOptionalAttributes(ParameterList pl)
            throws Exception {
        return getAttributes(pl, false);
    }

    public Set<AttributeExchange> getRequiredAttributes(UserSession us)
            throws Exception {
        return getAttributes(us.getParameterList(), true);
    }

    public Set<AttributeExchange> getRequiredAttributes(ParameterList pl)
            throws Exception {
        return getAttributes(pl, true);
    }

    public Set<AttributeExchange> getAttributes(ParameterList pl,
            boolean required) throws Exception {
        Message m = fetchAttributes(pl);
        if (m == null)
            return null;
        MessageExtension me = m.getExtension(AxMessage.OPENID_NS_AX);
        if (me instanceof FetchRequest) {
            FetchRequest fr = (FetchRequest) me;
            Map<?, ?> attrs = fr.getAttributes(required);
            Set<AttributeExchange> toRet = new TreeSet<AttributeExchange>();
            for (Object key : attrs.keySet()) {
                String type = (String) attrs.get(key);
                AttributeExchange ax = AttributeExchange.valueOfType(type);
                if (ax != null)
                    toRet.add(ax);
            }
            return toRet;
        }
        return null;
    }

    @SuppressWarnings("rawtypes")
    public Map getOptionalAttributes(Message m) throws Exception {
        FetchRequest req = (FetchRequest) m
                .getExtension(AxMessage.OPENID_NS_AX);
        return req.getAttributes(false);
    }

    @SuppressWarnings("rawtypes")
    public Map getRequiredAttributes(Message m) throws Exception {
        FetchRequest req = (FetchRequest) m
                .getExtension(AxMessage.OPENID_NS_AX);
        return req.getAttributes(true);
    }

    public UserSession getSession(String sessionId) {
        return this.sessions.get(sessionId);
    }

    public ProviderResult processOPRequest(ServerManager sm, ParameterList pl,
            Request req, Response res, UserSession us) {
        String modeParam = null;// pl.getParameterValue(OPENID_MODE);
        OpenIdMode mode = null;
        Message response;

        if (pl == null && us != null) {
            pl = us.getParameterList();

        }
        mode = OpenIdMode.valueOf(pl.getParameterValue(OPENID_MODE));

        try {

        } catch (Exception e) {
            Engine.getAnonymousLogger().warning(
                    "No Known openid.mode: " + modeParam);
            mode = OpenIdMode.errorMode;
        }
        Engine.getAnonymousLogger().info("processRequest: " + mode);
        switch (mode) {
        case associate:
            response = sm.associationResponse(pl);
            return new ProviderResult(OPR.OK, response.keyValueFormEncoding());
        case checkid_setup:
        case checkid_immediate:
            if (us == null || us.getUser() == null) { // this means no
                                                      // authorization
                // has taken place yet
                String session = UUID.randomUUID().toString();
                this.sessions.put(session, new UserSession(pl));
                return new ProviderResult(OPR.GET_USER, session);
            }
            OpenIdUser user = us.getUser();
            response = sm.authResponse(pl, user.getClaimedId(),
                    user.getClaimedId(), user.getApproved());
            // add any attributes:

            if (response instanceof DirectError) {
                return new ProviderResult(OPR.OK,
                        response.keyValueFormEncoding());
            }
            if (us.getUser().attributes() != null
                    && us.getUser().attributes().size() > 0) {
                FetchResponse fr = null;
                fr = FetchResponse.createFetchResponse();
                for (AttributeExchange attr : us.getUser().attributes()) {
                    String val = us.getUser().getAXValue(attr);
                    if (val != null) {
                        try {
                            fr.addAttribute(attr.getName(), attr.getSchema(),
                                    val);
                        } catch (MessageException e) {
                            // TODO Auto-generated catch block
                            e.printStackTrace();
                        }
                    }
                }
                if (fr.getAttributes().size() > 0) {
                    try {
                        response.addExtension(fr);
                        sm.sign((AuthSuccess) response);
                    } catch (Exception e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }
                }

            }
            res.redirectSeeOther(response.getDestinationUrl(true));
            return new ProviderResult(OPR.OK, "");
        case check_authentication:
            response = sm.verify(pl);
            return new ProviderResult(OPR.OK, response.keyValueFormEncoding());
        case errorMode:
            response = DirectError.createDirectError("Unknown request");
            return new ProviderResult(OPR.OK, response.keyValueFormEncoding());
        }
        return null;
    }

}
