001/*
002 * Copyright 2013 Atteo.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.atteo.moonshine.websocket.jsonmessages;
018
019import java.io.IOException;
020import java.lang.reflect.InvocationHandler;
021import java.lang.reflect.InvocationTargetException;
022import java.lang.reflect.Method;
023import java.lang.reflect.Proxy;
024import java.util.ArrayList;
025import java.util.List;
026
027import javax.inject.Provider;
028import javax.websocket.OnMessage;
029import javax.websocket.Session;
030
031import com.fasterxml.jackson.core.JsonProcessingException;
032import com.fasterxml.jackson.databind.ObjectMapper;
033import static com.google.common.base.Preconditions.checkState;
034
035public class HandlerDispatcher {
036    private final List<OnMessageMethodMetadata> onMessageMethods = new ArrayList<>();
037    private final ObjectMapper encoderObjectMapper = new ObjectMapper();
038    private final ObjectMapper decoderObjectMapper = new ObjectMapper();
039
040    public <T> void addHandler(Class<T> klass, Provider<? extends T> provider) {
041        for (Method method : klass.getMethods()) {
042            if (method.isAnnotationPresent(OnMessage.class)) {
043                registerOnMessageMethod(method, provider);
044            }
045        }
046    }
047
048    public <T> void addHandler(final T handler) {
049        this.addHandler((Class<T>)handler.getClass(), new Provider<T>() {
050            @Override
051            public T get() {
052                return handler;
053            }
054        });
055    }
056
057    public <T> SenderProvider<T> addSender(Class<T> klass) {
058        checkState(klass.isInterface(), "Provided Class object must represent an interface");
059
060        for (Method method : klass.getMethods()) {
061            registerSenderMethod(method);
062        }
063        @SuppressWarnings("unchecked")
064        final Class<T> proxyClass = (Class<T>) Proxy.getProxyClass(Thread.currentThread().getContextClassLoader(),
065                klass);
066
067        class SenderInvocationHandler implements InvocationHandler {
068            private final Session session;
069
070            public SenderInvocationHandler(Session session) {
071                this.session = session;
072            }
073
074            @Override
075            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
076                String request = encoderObjectMapper.writeValueAsString(args[0]);
077                session.getBasicRemote().sendText(request);
078                return null;
079            }
080        }
081
082        return new SenderProvider<T>() {
083            @Override
084            public T get(Session session) {
085                try {
086                    return proxyClass.getConstructor(new Class<?>[] { InvocationHandler.class }).newInstance(
087                            new SenderInvocationHandler(session));
088                } catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
089                        | IllegalArgumentException | InvocationTargetException ex) {
090                    throw new RuntimeException(ex);
091                }
092            }
093        };
094    }
095
096    private void registerOnMessageMethod(Method method, Provider<?> provider) {
097        Class<?>[] parameterTypes = method.getParameterTypes();
098        if (parameterTypes.length != 1) {
099            throw new RuntimeException("Method marked with @" + OnMessage.class.getSimpleName() +
100                    " must have exactly one argument whose super class is " + JsonMessage.class.getSimpleName());
101        }
102        Class<?> parameterType = parameterTypes[0];
103        if (!JsonMessage.class.isAssignableFrom(parameterType)) {
104            throw new RuntimeException("Method marked with @" + OnMessage.class.getSimpleName() +
105                    " must have exactly one argument whose super class is " + JsonMessage.class.getSimpleName());
106        }
107        decoderObjectMapper.registerSubtypes(parameterType);
108        Class<?> returnType = method.getReturnType();
109
110        if (returnType != Void.TYPE) {
111            encoderObjectMapper.registerSubtypes(returnType);
112        }
113        onMessageMethods.add(new OnMessageMethodMetadata(parameterType, provider, method));
114    }
115
116    private void registerSenderMethod(Method method) {
117        Class<?>[] parameterTypes = method.getParameterTypes();
118        if (parameterTypes.length != 1) {
119            throw new RuntimeException("Sender method" +
120                    " must have exactly one argument whose super class is " + JsonMessage.class.getSimpleName());
121        }
122        Class<?> parameterType = parameterTypes[0];
123        if (!JsonMessage.class.isAssignableFrom(parameterType)) {
124            throw new RuntimeException("Sender method" +
125                    " must have exactly one argument whose super class is " + JsonMessage.class.getSimpleName());
126        }
127        encoderObjectMapper.registerSubtypes(parameterType);
128        Class<?> returnType = method.getReturnType();
129        if (returnType != Void.TYPE) {
130            throw new RuntimeException("Sender method must have " + Void.class.getSimpleName() + " return type");
131        }
132    }
133
134    public String callOnMessage(String message) throws JsonProcessingException, IOException {
135        JsonMessage request = decoderObjectMapper.readValue(message, JsonMessage.class);
136        for (OnMessageMethodMetadata metadata : onMessageMethods) {
137            if (metadata.getMessageType().isAssignableFrom(request.getClass())) {
138                JsonMessage response = metadata.call(request);
139                if (response == null) {
140                    return null;
141                } else {
142                    return encoderObjectMapper.writeValueAsString(response);
143                }
144            }
145        }
146        throw new RuntimeException("Unknown message type: " + request.getClass().getName());
147    }
148
149    private static class OnMessageMethodMetadata {
150        private final Provider<?> provider;
151        private final Class<?> messageType;
152        private final Method method;
153
154        public OnMessageMethodMetadata(Class<?> messageType, Provider<?> provider, Method method) {
155            super();
156            this.provider = provider;
157            this.messageType = messageType;
158            this.method = method;
159        }
160
161        public Class<?> getMessageType() {
162            return messageType;
163        }
164
165        public JsonMessage call(JsonMessage message) {
166            try {
167                Object handler = provider.get();
168                return (JsonMessage) method.invoke(handler, message);
169            } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) {
170                throw new RuntimeException(ex);
171            }
172        }
173    }
174}