001/*
002 * Copyright 2013 Atteo.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.atteo.moonshine.websocket;
018
019import java.util.ArrayList;
020import java.util.Collections;
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024
025import javax.inject.Provider;
026import javax.servlet.ServletContext;
027import javax.servlet.ServletContextEvent;
028import javax.servlet.ServletContextListener;
029import javax.websocket.Decoder;
030import javax.websocket.DeploymentException;
031import javax.websocket.Encoder;
032import javax.websocket.Endpoint;
033import javax.websocket.Extension;
034import javax.websocket.server.ServerContainer;
035import javax.websocket.server.ServerEndpointConfig;
036import javax.xml.bind.annotation.XmlElement;
037import javax.xml.bind.annotation.XmlIDREF;
038
039import org.atteo.moonshine.TopLevelService;
040import org.atteo.moonshine.services.ImportService;
041import org.atteo.moonshine.webserver.ServletContainer;
042
043import com.google.inject.Module;
044import com.google.inject.PrivateModule;
045
046/**
047 * WebSocket container.
048 */
049public abstract class WebSocketContainerService extends TopLevelService {
050    @ImportService
051    @XmlIDREF
052    @XmlElement
053    protected ServletContainer servletContainer;
054
055    private final List<EndpointDefinition<?>> endpoints = new ArrayList<>();
056
057    protected <T> EndpointDefinition<T> createEndpointDefinition(Class<T> klass) {
058        return new EndpointDefinition<>(klass);
059    }
060
061    /**
062     * Adds ordinary endpoint.
063     */
064    public <T extends Endpoint> EndpointBuilder<T> addEndpoint(Class<T> klass) {
065        EndpointDefinition<T> definition = createEndpointDefinition(klass);
066        endpoints.add(definition);
067        return definition;
068    }
069
070    /**
071     * Adds annotated endpoint.
072     * @param endpoint
073     */
074    public void addAnnotatedEndpoint(Class<?> endpoint) {
075        EndpointDefinition<?> definition = createEndpointDefinition(endpoint);
076        endpoints.add(definition);
077    }
078
079    /**
080     * Adds annotated endpoint.
081     * @param endpoint
082     */
083    public <T> void addAnnotatedEndpoint(Class<T> endpoint, Provider<? extends T> provider) {
084        addAnnotatedEndpointInternal(endpoint, provider);
085    }
086
087    private <T> void addAnnotatedEndpointInternal(Class<T> endpoint, Provider<? extends T> provider) {
088        @SuppressWarnings("unchecked")
089        EndpointDefinition<T> definition = createEndpointDefinition(endpoint);
090        definition.provider(provider);
091        endpoints.add(definition);
092    }
093
094    @Override
095    public Module configure() {
096        return new PrivateModule() {
097            @Override
098            protected void configure() {
099                servletContainer.addListener(new Listener());
100            }
101        };
102    }
103
104    public static interface EndpointBuilder<T> {
105        EndpointBuilder<T> pattern(String pattern);
106        EndpointBuilder<T> provider(Provider<? extends T> provider);
107        EndpointBuilder<T> addEncoder(Class<? extends Encoder> encoder);
108        EndpointBuilder<T> addDecoder(Class<? extends Decoder> encoder);
109        EndpointBuilder<T> addUserProperty(String key, Object value);
110    }
111
112    private class Listener implements ServletContextListener {
113        @Override
114        public void contextInitialized(ServletContextEvent contextEvent) {
115            ServletContext context = contextEvent.getServletContext();
116            ServerContainer container = (ServerContainer) context.getAttribute(ServerContainer.class.getName());
117            for (EndpointDefinition<?> endpointDefinition : endpoints) {
118                try {
119                    container.addEndpoint(endpointDefinition);
120                } catch (DeploymentException ex) {
121                    throw new RuntimeException(ex);
122                }
123            }
124        }
125
126        @Override
127        public void contextDestroyed(ServletContextEvent sce) {
128        }
129    }
130
131    public static class EndpointDefinition<T> implements ServerEndpointConfig, EndpointBuilder<T> {
132        private final Class<T> endpointClass;
133        private Provider<? extends T> provider;
134        private String pattern;
135        private final List<Class<? extends Encoder>> encoders = new ArrayList<>();
136        private final List<Class<? extends Decoder>> decoders = new ArrayList<>();
137        protected final Map<String, Object> userProperties = new HashMap<>();
138
139        public EndpointDefinition(Class<T> endpointClass) {
140            this.endpointClass = endpointClass;
141        }
142
143        @Override
144        public EndpointBuilder<T> pattern(String pattern) {
145            this.pattern = pattern;
146            return this;
147        }
148
149        @Override
150        public EndpointBuilder<T> provider(Provider<? extends T> provider) {
151            this.provider = provider;
152            return this;
153        }
154
155        @Override
156        public EndpointBuilder<T> addEncoder(Class<? extends Encoder> encoder) {
157            encoders.add(encoder);
158            return this;
159        }
160
161        @Override
162        public EndpointBuilder<T> addDecoder(Class<? extends Decoder> decoder) {
163            decoders.add(decoder);
164            return this;
165        }
166
167        @Override
168        public EndpointBuilder<T> addUserProperty(String key, Object value) {
169            userProperties.put(key, value);
170            return this;
171        }
172
173        @Override
174        public Class<T> getEndpointClass() {
175            return endpointClass;
176        }
177
178        @Override
179        public String getPath() {
180            return pattern;
181        }
182
183        @Override
184        public List<String> getSubprotocols() {
185            return Collections.emptyList();
186        }
187
188        @Override
189        public List<Extension> getExtensions() {
190            return Collections.emptyList();
191        }
192
193        @Override
194        public ServerEndpointConfig.Configurator getConfigurator() {
195            return new Configurator() {
196                @Override
197                public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
198                    if (provider == null) {
199                        return super.getEndpointInstance(endpointClass);
200                    }
201                    return endpointClass.cast(provider.get());
202                }
203            };
204        }
205
206        @Override
207        public List<Class<? extends Encoder>> getEncoders() {
208            return encoders;
209        }
210
211        @Override
212        public List<Class<? extends Decoder>> getDecoders() {
213            return decoders;
214        }
215
216        @Override
217        public Map<String, Object> getUserProperties() {
218            return userProperties;
219        }
220    }
221}