From b978c776b566ea23c33fb2b015b6d62b2e699e65 Mon Sep 17 00:00:00 2001 From: wangyu <727842003@qq.com> Date: Sat, 12 Oct 2024 15:35:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=AE=8C=E6=88=90?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../flyfish/boot/cas/config/CASConfig.java | 91 ++++ .../config/session/WebSessionDecorator.java | 194 ++++++++ .../config/session/WebSessionListener.java | 21 + .../boot/cas/controller/IndexController.java | 16 + .../flyfish/boot/cas/filter/CASContext.java | 38 +- .../boot/cas/filter/CASContextInit.java | 13 + .../flyfish/boot/cas/filter/CASFilter.java | 443 ++++++++++-------- .../boot/cas/filter/CASLoginFilter.java | 20 + .../flyfish/boot/cas/filter/CASParameter.java | 5 +- src/main/resources/application.properties | 1 - src/main/resources/application.yml | 13 + 11 files changed, 645 insertions(+), 210 deletions(-) create mode 100644 src/main/java/dev/flyfish/boot/cas/config/CASConfig.java create mode 100644 src/main/java/dev/flyfish/boot/cas/config/session/WebSessionDecorator.java create mode 100644 src/main/java/dev/flyfish/boot/cas/config/session/WebSessionListener.java create mode 100644 src/main/java/dev/flyfish/boot/cas/controller/IndexController.java create mode 100644 src/main/java/dev/flyfish/boot/cas/filter/CASContextInit.java create mode 100644 src/main/java/dev/flyfish/boot/cas/filter/CASLoginFilter.java delete mode 100644 src/main/resources/application.properties create mode 100644 src/main/resources/application.yml diff --git a/src/main/java/dev/flyfish/boot/cas/config/CASConfig.java b/src/main/java/dev/flyfish/boot/cas/config/CASConfig.java new file mode 100644 index 0000000..543263f --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/config/CASConfig.java @@ -0,0 +1,91 @@ +package dev.flyfish.boot.cas.config; + +import dev.flyfish.boot.cas.config.session.WebSessionDecorator; +import dev.flyfish.boot.cas.config.session.WebSessionListener; +import dev.flyfish.boot.cas.filter.CASFilter; +import dev.flyfish.boot.cas.filter.CASParameter; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.web.ServerProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.session.DefaultWebSessionManager; +import org.springframework.web.server.session.InMemoryWebSessionStore; +import org.springframework.web.server.session.WebSessionManager; +import org.springframework.web.server.session.WebSessionStore; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.List; + +/** + * cas核心配置 + * + * @author wangyu + */ +@Configuration +public class CASConfig { + + @Bean + @ConfigurationProperties("cas.filter") + public CASParameter casParameter() { + return new CASParameter(); + } + + @Bean + public CASFilter casFilter(CASParameter casParameter) { + return new CASFilter(casParameter); + } + + @Bean + public WebSessionStore webSessionStore(WebSessionManager webSessionManager, ServerProperties serverProperties, + ObjectProvider listeners) { + if (webSessionManager instanceof DefaultWebSessionManager defaultWebSessionManager) { + Duration timeout = serverProperties.getReactive().getSession().getTimeout(); + int maxSessions = serverProperties.getReactive().getSession().getMaxSessions(); + ListenableWebSessionStore sessionStore = new ListenableWebSessionStore(timeout, listeners); + sessionStore.setMaxSessions(maxSessions); + defaultWebSessionManager.setSessionStore(sessionStore); + return sessionStore; + } + throw new IllegalStateException("Cannot find web session manager"); + } + + /** + * 处理session销毁,保证正确退出 + * + * @param casFilter 过滤器 + * @return 结果 + */ + @Bean + public WebSessionListener singleSignOutSessionListener(CASFilter casFilter) { + return new WebSessionListener() { + @Override + public Mono onSessionInvalidated(WebSession session) { + return casFilter.getSessionMappingStorage().removeBySessionById(session.getId()); + } + }; + } + + static final class ListenableWebSessionStore extends InMemoryWebSessionStore { + private final Duration timeout; + private final List listeners; + + private ListenableWebSessionStore(Duration timeout, ObjectProvider listeners) { + this.timeout = timeout; + this.listeners = listeners.stream().toList(); + } + + public Mono createWebSession() { + return super.createWebSession() + .map(session -> (WebSession) new WebSessionDecorator(session, listeners)) + .doOnSuccess(this::setMaxIdleTime); + } + + private void setMaxIdleTime(WebSession session) { + session.setMaxIdleTime(this.timeout); + } + } + +} diff --git a/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionDecorator.java b/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionDecorator.java new file mode 100644 index 0000000..edf4be3 --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionDecorator.java @@ -0,0 +1,194 @@ +package dev.flyfish.boot.cas.config.session; + +import lombok.RequiredArgsConstructor; +import org.springframework.lang.Nullable; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +@RequiredArgsConstructor +public class WebSessionDecorator implements WebSession { + + private final WebSession decorated; + + private final List listeners; + + /** + * Return a unique session identifier. + */ + @Override + public String getId() { + return decorated.getId(); + } + + /** + * Return a map that holds session attributes. + */ + @Override + public Map getAttributes() { + return decorated.getAttributes(); + } + + /** + * Return the session attribute value if present. + * + * @param name the attribute name + * @return the attribute value + */ + @Nullable + @Override + public T getAttribute(String name) { + return decorated.getAttribute(name); + } + + /** + * Return the session attribute value or if not present raise an + * {@link IllegalArgumentException}. + * + * @param name the attribute name + * @return the attribute value + */ + @Override + public T getRequiredAttribute(String name) { + return decorated.getRequiredAttribute(name); + } + + /** + * Return the session attribute value, or a default, fallback value. + * + * @param name the attribute name + * @param defaultValue a default value to return instead + * @return the attribute value + */ + @Override + public T getAttributeOrDefault(String name, T defaultValue) { + return decorated.getAttributeOrDefault(name, defaultValue); + } + + /** + * Force the creation of a session causing the session id to be sent when + * {@link #save()} is called. + */ + @Override + public void start() { + decorated.start(); + } + + /** + * Whether a session with the client has been started explicitly via + * {@link #start()} or implicitly by adding session attributes. + * If "false" then the session id is not sent to the client and the + * {@link #save()} method is essentially a no-op. + */ + @Override + public boolean isStarted() { + return decorated.isStarted(); + } + + /** + * Generate a new id for the session and update the underlying session + * storage to reflect the new id. After a successful call {@link #getId()} + * reflects the new session id. + * + * @return completion notification (success or error) + */ + @Override + public Mono changeSessionId() { + return decorated.changeSessionId(); + } + + /** + * Invalidate the current session and clear session storage. + * + * @return completion notification (success or error) + */ + @Override + public Mono invalidate() { + // 后续处理 + Mono consumer = Mono.defer(() -> listeners.stream() + .map(listener -> listener.onSessionInvalidated(this.decorated)) + .reduce(Mono::then) + .orElse(Mono.empty())); + + return decorated.invalidate().then(consumer); + } + + /** + * Save the session through the {@code WebSessionStore} as follows: + *
    + *
  • If the session is new (i.e. created but never persisted), it must have + * been started explicitly via {@link #start()} or implicitly by adding + * attributes, or otherwise this method should have no effect. + *
  • If the session was retrieved through the {@code WebSessionStore}, + * the implementation for this method must check whether the session was + * {@link #invalidate() invalidated} and if so return an error. + *
+ *

Note that this method is not intended for direct use by applications. + * Instead it is automatically invoked just before the response is + * committed. + * + * @return {@code Mono} to indicate completion with success or error + */ + @Override + public Mono save() { + return decorated.save(); + } + + /** + * Return {@code true} if the session expired after {@link #getMaxIdleTime() + * maxIdleTime} elapsed. + *

Typically expiration checks should be automatically made when a session + * is accessed, a new {@code WebSession} instance created if necessary, at + * the start of request processing so that applications don't have to worry + * about expired session by default. + */ + @Override + public boolean isExpired() { + return decorated.isExpired(); + } + + /** + * Return the time when the session was created. + */ + @Override + public Instant getCreationTime() { + return decorated.getCreationTime(); + } + + /** + * Return the last time of session access as a result of user activity such + * as an HTTP request. Together with {@link #getMaxIdleTime() + * maxIdleTimeInSeconds} this helps to determine when a session is + * {@link #isExpired() expired}. + */ + @Override + public Instant getLastAccessTime() { + return decorated.getLastAccessTime(); + } + + /** + * Configure the max amount of time that may elapse after the + * {@link #getLastAccessTime() lastAccessTime} before a session is considered + * expired. A negative value indicates the session should not expire. + * + * @param maxIdleTime + */ + @Override + public void setMaxIdleTime(Duration maxIdleTime) { + decorated.setMaxIdleTime(maxIdleTime); + } + + /** + * Return the maximum time after the {@link #getLastAccessTime() + * lastAccessTime} before a session expires. A negative time indicates the + * session doesn't expire. + */ + @Override + public Duration getMaxIdleTime() { + return decorated.getMaxIdleTime(); + } +} diff --git a/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionListener.java b/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionListener.java new file mode 100644 index 0000000..6a946d9 --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/config/session/WebSessionListener.java @@ -0,0 +1,21 @@ +package dev.flyfish.boot.cas.config.session; + +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; + +/** + * web session监听器 + * + * @author wangyu + * 基于装饰器增强实现,可灵活处理 + */ +public interface WebSessionListener { + + default Mono onSessionCreated(WebSession session) { + return Mono.empty(); + } + + default Mono onSessionInvalidated(WebSession session) { + return Mono.empty(); + } +} diff --git a/src/main/java/dev/flyfish/boot/cas/controller/IndexController.java b/src/main/java/dev/flyfish/boot/cas/controller/IndexController.java new file mode 100644 index 0000000..3e0d3cf --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/controller/IndexController.java @@ -0,0 +1,16 @@ +package dev.flyfish.boot.cas.controller; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.HashMap; + +@RestController +public class IndexController { + + @GetMapping("/hello") + public Object index() { + return new HashMap<>(); + } +} diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java b/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java index 587a93d..d3f61e3 100644 --- a/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java @@ -13,17 +13,16 @@ import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebSession; -import org.springframework.web.server.session.WebSessionManager; -import org.springframework.web.server.session.WebSessionStore; import reactor.core.publisher.Mono; +import java.net.URI; import java.util.Map; /** * cas 上下文 */ @RequiredArgsConstructor(access = AccessLevel.PRIVATE) -class CASContext { +public class CASContext { @Getter private String ticket; @@ -38,6 +37,10 @@ class CASContext { private final Map> parameters = new ConcurrentReferenceHashMap<>(); + @Setter + @Getter + private String username; + static Mono create(ServerWebExchange exchange, WebFilterChain chain) { return new CASContext(exchange, chain).init(); } @@ -50,21 +53,20 @@ class CASContext { Mono sessionMono = exchange.getSession() .doOnNext(session -> this.session = session); return Mono.zipDelayError(ticketMono, sessionMono) + .onErrorContinue((e, v) -> e.printStackTrace()) .thenReturn(this); } + boolean isTokenRequest() { + return StringUtils.hasText(ticket); + } + Mono filter() { return chain.filter(exchange); } - /** - * 获取参数 - * - * @param key 键 - * @return 异步结果 - */ - Mono getParameter(String key) { - return parameters.computeIfAbsent(key, this::computeParameter); + Mono redirect(String url) { + return chain.filter(exchange.mutate().request(builder -> builder.uri(URI.create(url))).build()); } ServerHttpRequest getRequest() { @@ -93,6 +95,20 @@ class CASContext { .mapNotNull(formData -> formData.getFirst(key)); } + void setSessionAttribute(String key, Object value) { + session.getAttributes().put(key, value); + } + + /** + * 获取参数 + * + * @param key 键 + * @return 异步结果 + */ + private Mono getParameter(String key) { + return parameters.computeIfAbsent(key, this::computeParameter); + } + private Mono computeParameter(String key) { return this.readParameter(key).cache(); } diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASContextInit.java b/src/main/java/dev/flyfish/boot/cas/filter/CASContextInit.java new file mode 100644 index 0000000..da15b35 --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASContextInit.java @@ -0,0 +1,13 @@ +package dev.flyfish.boot.cas.filter; + +/** + * 上下文初始化逻辑 + * + * @author wangyu + */ +public interface CASContextInit { + + String getTranslatorUser(String username); + + void initContext(CASContext casContext, String username); +} diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java b/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java index 4c8e3b8..411a6a7 100644 --- a/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java @@ -1,13 +1,15 @@ package dev.flyfish.boot.cas.filter; -import edu.yale.its.tp.cas.client.*; +import edu.yale.its.tp.cas.client.CASAuthenticationException; +import edu.yale.its.tp.cas.client.CASReceipt; +import edu.yale.its.tp.cas.client.ProxyTicketValidator; import edu.yale.its.tp.cas.util.XmlUtils; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpCookie; import org.springframework.http.HttpMethod; -import org.springframework.http.codec.HttpMessageReader; -import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; @@ -15,12 +17,13 @@ import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebSession; import reactor.core.publisher.Mono; -import java.io.IOException; import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; +import java.net.URI; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; +import java.util.Objects; /** * cas filter的webflux实现 @@ -49,15 +52,32 @@ public class CASFilter implements WebFilter { static final String CAS_FILTER_EXCLUSION = "edu.yale.its.tp.cas.client.filter.filterExclusion"; private final CASParameter parameter; + private final CASContextInit initializer; + @Getter + private final SessionMappingStorage sessionMappingStorage = new SessionMappingStorage.HashMapBackedSessionStorage(); - private static SessionMappingStorage SESSION_MAPPING_STORAGE = new SessionMappingStorage.HashMapBackedSessionStorage(); + public CASFilter(CASParameter parameter) { + this.parameter = parameter.checked(); + this.initializer = createInitializer(); + } - private List> messageReaders; - - public CASFilter(CASParameter parameter, ServerCodecConfigurer codecConfigurer) { - parameter.check(); - this.parameter = parameter; - this.messageReaders = codecConfigurer.getReaders(); + private CASContextInit createInitializer() { + if (null != parameter.casInitContextClass) { + try { + // 未正确配置类型,抛弃 + Class cls = parameter.casInitContextClass.asSubclass(CASContextInit.class); + // 实例化对象并返回 + return cls.getConstructor().newInstance(); + } catch (ClassCastException e) { + log.warn("cas context init class not implements CASContextInit", e); + } catch (IllegalArgumentException | IllegalAccessException | InstantiationException + | SecurityException | NoSuchMethodException e) { + log.error("error when initialize the context init class", e); + } catch (InvocationTargetException e) { + log.error("error when create the cas context initializer's instance!", e); + } + } + return null; } private boolean isReceiptAcceptable(CASReceipt receipt) { @@ -70,12 +90,12 @@ public class CASFilter implements WebFilter { } } - private CASReceipt getAuthenticatedUser(ServerHttpRequest request, String ticket) throws CASAuthenticationException { + private CASReceipt getAuthenticatedUser(CASContext context) throws CASAuthenticationException { log.trace("entering getAuthenticatedUser()"); ProxyTicketValidator pv = new ProxyTicketValidator(); pv.setCasValidateUrl(parameter.casValidate); - pv.setServiceTicket(ticket); - pv.setService(this.getService(request)); + pv.setServiceTicket(context.getTicket()); + pv.setService(this.getService(context)); pv.setRenew(parameter.casRenew); if (parameter.casProxyCallbackUrl != null) { pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl); @@ -86,23 +106,71 @@ public class CASFilter implements WebFilter { return CASReceipt.getReceipt(pv); } - private String getService(ServerHttpRequest request) { + private String getService(CASContext context) { log.trace("entering getService()"); + if (parameter.casServerName == null && parameter.casServiceUrl == null) { throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName"); - } else { - String serviceString; - if (parameter.casServiceUrl != null) { - serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8); - } else { - serviceString = Util.getService(request, parameter.casServerName); - } - - if (log.isTraceEnabled()) { - log.trace("returning from getService() with service [{}]", serviceString); - } - return serviceString; } + + String serviceString; + if (parameter.casServiceUrl != null) { + serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8); + } else { + serviceString = computeService(context, parameter.casServerName); + } + + if (log.isTraceEnabled()) { + log.trace("returning from getService() with service [{}]", serviceString); + } + return serviceString; + } + + /** + * 计算服务地址,主要是替换url中server的部分,并去除ticket + * + * @param context 上下文 + * @param server 服务 + * @return 结果 + */ + public static String computeService(CASContext context, String server) { + if (log.isTraceEnabled()) { + log.trace("entering getService({}, {})", context, server); + } + + + if (server == null) { + log.error("getService() argument \"server\" was illegally null."); + throw new IllegalArgumentException("name of server is required"); + } + + URI uri = context.getRequest().getURI(); + + StringBuilder sb = new StringBuilder(); + + sb.append(uri.getScheme()).append("://").append(server).append(uri.getPath()); + if (uri.getQuery() != null) { + String query = uri.getQuery(); + + int ticketLoc = query.indexOf("ticket="); + if (ticketLoc == -1) { + sb.append("?").append(query); + } else if (ticketLoc > 0) { + ticketLoc = query.indexOf("&ticket="); + if (ticketLoc == -1) { + sb.append("?").append(query); + } else if (ticketLoc > 0) { + sb.append("?").append(query, 0, ticketLoc); + } + } + } + + String encodedService = URLEncoder.encode(sb.toString(), StandardCharsets.UTF_8); + if (log.isTraceEnabled()) { + log.trace("returning from getService() with encoded service [{}]", encodedService); + } + + return encodedService; } /** @@ -110,71 +178,58 @@ public class CASFilter implements WebFilter { * * @param context 上下文 * @return 结果 - * @throws IOException 异常 */ private Mono redirectToCAS(CASContext context) { ServerHttpRequest request = context.getRequest(); - ServerHttpResponse response = context.getResponse(); String sessionId = context.getSession().getId(); + log.trace("entering redirectToCAS()"); - String casLoginString = parameter.casLogin + "?service=" + this.getService(request) + - (parameter.casRenew ? "&renew=true" : "") + (parameter.casGateway ? "&gateway=true" : ""); - String sCookie; - sCookie = parameter.casServerName + request.getPath().contextPath().value(); - casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + sessionId; + StringBuilder casLoginString = new StringBuilder() + .append(parameter.casLogin) + .append("?service=").append(this.getService(context)) + .append(parameter.casRenew ? "&renew=true" : "") + .append(parameter.casGateway ? "&gateway=true" : ""); - sCookie = request.getHeaders().getFirst("Cookie"); - String cookie = null; - if (sCookie != null) { - String[] sCookies = sCookie.split(";"); - - for (int i = 0; i < sCookies.length; ++i) { - if (sCookies[i].indexOf("JSESSIONID=") != -1) { - cookie = sCookies[i].split("JSESSIONID=")[1]; - } - } + if (StringUtils.hasText(sessionId)) { + String appId = parameter.casServerName + request.getPath().contextPath().value(); + casLoginString.append("&appId=").append(appId) + .append("&sessionId=").append(sessionId); } - if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) { - casLoginString = casLoginString + "&timeOut=" + cookie; - if (log.isDebugEnabled()) { - log.debug("Session is timeout. The timeout session is {}", cookie); - } + List cookies = request.getCookies().get("JSESSIONID"); + if (!CollectionUtils.isEmpty(cookies)) { + cookies.stream() + .filter(Objects::nonNull) + .map(HttpCookie::getValue) + .filter(cookie -> !cookie.equals("null") && !cookie.equals(sessionId)) + .peek(cookie -> log.debug("Session is timeout. The timeout session is {}", cookie)) + .findFirst() + .ifPresent(cookie -> casLoginString.append("&timeOut=").append(cookie)); } - if (log.isDebugEnabled()) { - log.debug("Redirecting browser to [" + casLoginString + ")"); - } - - response.sendRedirect(casLoginString); - if (log.isTraceEnabled()) { - log.trace("returning from redirectToCAS()"); - } + log.debug("Redirecting browser to [{})", casLoginString); + log.trace("returning from redirectToCAS()"); + return context.redirect(casLoginString.toString()); } - private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException { + private Mono redirectToInitFailure(CASContext context, String cause) { log.trace("entering redirectToInitFailure()"); - String casLoginString = this.casLogin + "?action=initFailure"; + String casLoginString = parameter.casLogin + "?action=initFailure"; if (cause != null && cause.equals("Illegal user")) { - casLoginString = casLoginString + "&userIllegal=true"; + casLoginString += "&userIllegal=true"; } - String locale = request.getParameter("locale"); + String locale = context.getQuery("locale"); if (locale != null) { - casLoginString = casLoginString + "&locale=" + locale; + casLoginString += "&locale=" + locale; } - log.debug("Redirecting browser to [" + casLoginString + ")"); - - response.sendRedirect(casLoginString); + log.debug("Redirecting browser to [{})", casLoginString); log.trace("returning from redirectToInitFailure()"); - } - - public static SessionMappingStorage getSessionMappingStorage() { - return SESSION_MAPPING_STORAGE; + return context.redirect(casLoginString); } private boolean isExclusion(String url) { @@ -186,117 +241,114 @@ public class CASFilter implements WebFilter { } private Mono translate(CASContext context) { + // 是代理回调地址,通过 if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath()) && context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) { log.trace("passing through what we hope is CAS's request for proxy ticket receptor."); return context.filter(); - } else { - if (parameter.wrapRequest) { - log.trace("Wrapping request with CASFilterRequestWrapper."); - // todo 暂时啥也不干,看看有无问题 -// request = new CASFilterRequestWrapper((HttpServletRequest) request); - } - - WebSession session = context.getSession(); - // 使用了用户标记,快速跳过 - if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) { - return context.filter(); - } - - // 获取receipt - CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT); - if (receipt != null && this.isReceiptAcceptable(receipt)) { - log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter.."); - return context.filter(); - } - // 跳过请求 - if (this.isExclusion(context.getPath())) { - return context.filter(); - } - - // 判断票据 - String ticket = context.getTicket(); - if (StringUtils.hasText(ticket)) { - try { - receipt = this.getAuthenticatedUser(context.getRequest(), ticket); - } catch (CASAuthenticationException var22) { - return this.redirectToCAS(context); - } - - if (!this.isReceiptAcceptable(receipt)) { - throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); - } else { - String pt = context.getQuery("pt"); - if (StringUtils.hasText(pt)) { - session.getAttributes().put(pt, receipt); - } - - String userName = receipt.getUserName(); - if (StringUtils.hasText(parameter.casInitContextClass)) { - try { - Class cls = Class.forName(parameter.casInitContextClass); - Object obj = cls.getConstructor().newInstance(); - if (obj instanceof IContextInit) { - Method translatorMethod = cls.getMethod("getTranslatorUser", String.class); - userName = (String) translatorMethod.invoke(obj, userName); - Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class); - initContextMethod.invoke(obj, request, response, fc, userName); - } - } catch (ClassNotFoundException | IllegalArgumentException | IllegalAccessException | - InstantiationException | SecurityException | NoSuchMethodException e) { - e.printStackTrace(); - } catch (InvocationTargetException var19) { - InvocationTargetException e = var19; - String cause = e.getCause().getMessage(); - session.setAttribute("initFailure", cause); - this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); - e.printStackTrace(); - return Mono.empty(); - } - } - - session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName); - session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt); - session.removeAttribute(CAS_FILTER_GATEWAYED); - - log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt); - - log.trace("returning from doFilter()"); - - return context.filter(); - } - return Mono.empty(); - } else { - log.trace("CAS ticket was not present on request."); - boolean didGateway = Boolean.valueOf((String) session.getAttribute(CAS_FILTER_GATEWAYED)); - if (parameter.casLogin == null) { - log.error("casLogin was not set, so filter cannot redirect request for authentication."); - throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter"); - } else if (!didGateway) { - log.trace("Did not previously gateway. Setting session attribute to true."); - ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); - session.setAttribute(CAS_FILTER_GATEWAYED, "true"); - this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); - } else { - log.trace("Previously gatewayed."); - if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) { - if (session.getAttribute("initFailure") != null) { - String cause = (String) session.getAttribute("initFailure"); - this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); - } else { - ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); - session.setAttribute(CAS_FILTER_GATEWAYED, "true"); - this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); - } - - } else { - log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain."); - fc.doFilter((ServletRequest) request, response); - } - } - } - } + + // 请求包装,增强请求并完成自定义功能 + if (parameter.wrapRequest) { + log.trace("Wrapping request with CASFilterRequestWrapper."); + // todo 暂时啥也不干,看看有无问题 +// request = new CASFilterRequestWrapper((HttpServletRequest) request); + } + + WebSession session = context.getSession(); + Map sessionAttributes = session.getAttributes(); + + // 使用了用户标记,快速跳过 + if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) { + return context.filter(); + } + + // 获取receipt,若存在,则通过 + CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT); + if (receipt != null && this.isReceiptAcceptable(receipt)) { + log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter.."); + return context.filter(); + } + + // 命中排除地址,跳过请求 + if (this.isExclusion(context.getPath())) { + return context.filter(); + } + + // 判断票据 + String ticket = context.getTicket(); + // 存在票据时,验证票据 + if (StringUtils.hasText(ticket)) { + try { + receipt = this.getAuthenticatedUser(context); + } catch (CASAuthenticationException e) { + return this.redirectToCAS(context); + } + + if (!this.isReceiptAcceptable(receipt)) { + throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); + } + + // 记录receipt + String pt = context.getQuery("pt"); + if (StringUtils.hasText(pt)) { + context.setSessionAttribute(pt, receipt); + } + + // 获取到用户名 + String userName = receipt.getUserName(); + // 尝试初始化 + if (null != initializer) { + try { + String translated = initializer.getTranslatorUser(userName); + log.debug("translated username: {} to {}", userName, translated); + initializer.initContext(context, translated); + } catch (Exception e) { + String cause = e.getCause().getMessage(); + context.setSessionAttribute("initFailure", cause); + return this.redirectToInitFailure(context, cause); + } + } + + sessionAttributes.put(CAS_FILTER_USER, userName); + sessionAttributes.put(CAS_FILTER_RECEIPT, receipt); + sessionAttributes.remove(CAS_FILTER_GATEWAYED); + + if (log.isTraceEnabled()) { + log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt); + log.trace("returning from doFilter()"); + } + + return context.filter(); + } + + // 不存在票据,跳转验证 + log.trace("CAS ticket was not present on request."); + boolean didGateway = Boolean.parseBoolean(session.getAttribute(CAS_FILTER_GATEWAYED)); + if (parameter.casLogin == null) { + log.error("casLogin was not set, so filter cannot redirect request for authentication."); + throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter"); + } + + if (!didGateway) { + log.trace("Did not previously gateway. Setting session attribute to true."); + sessionAttributes.put(CAS_FILTER_GATEWAYED, "true"); + return this.redirectToCAS(context); + } + + log.trace("Previously gatewayed."); + if (!parameter.casGateway && session.getAttribute(CAS_FILTER_USER) == null) { + if (session.getAttribute("initFailure") != null) { + String cause = session.getAttribute("initFailure"); + return this.redirectToInitFailure(context, cause); + } + + sessionAttributes.put(CAS_FILTER_GATEWAYED, "true"); + return this.redirectToCAS(context); + } + + log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain."); + return context.filter(); } /** @@ -306,19 +358,30 @@ public class CASFilter implements WebFilter { * @return 结果 */ private Mono handle(CASContext context) { + // 优先处理token请求 + if (context.isTokenRequest()) { + String sessionId = context.getSession().getId(); + log.debug("Storing session identifier for {}", sessionId); + + // 包括ticket,尝试重新替换session + return sessionMappingStorage.removeBySessionById(sessionId) + .onErrorContinue((e, v) -> log.debug("error when remove session")) + .then(Mono.defer(() -> sessionMappingStorage.addSessionById(context.getTicket(), context.getSession()) + .then(translate(context)))); + } // post请求需要特殊处理 if (context.getMethod() == HttpMethod.POST) { - // 此处可能要求安全的获取参数(单独针对退出请求) + // 通过form表单获取注销请求,处理注销逻辑 return context.getFormData("logoutRequest") .doOnNext(payload -> log.trace("Logout request=[{}]", payload)) .defaultIfEmpty("") .flatMap(payload -> { if (StringUtils.hasText(payload)) { - String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex"); - if (StringUtils.hasText(sessionIdentifier)) { + String token = XmlUtils.getTextForElement(payload, "SessionIndex"); + if (StringUtils.hasText(token)) { // 满足条件时断路 - return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier) - .doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), sessionIdentifier)) + return sessionMappingStorage.removeSessionByMappingId(token) + .doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), token)) .flatMap(WebSession::invalidate) .doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e)) .onErrorComplete(); @@ -327,20 +390,8 @@ public class CASFilter implements WebFilter { // 继续执行 return translate(context); }); - } else { - String ticket = context.getTicket(); - String sessionId = context.getSession().getId(); - log.debug("Storing session identifier for {}", sessionId); - - // 包括ticket,尝试重新替换session - if (StringUtils.hasText(ticket)) { - return SESSION_MAPPING_STORAGE.removeBySessionById(sessionId) - .onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, context.getSession())) - .then(Mono.defer(() -> translate(context))); - } - - return translate(context); } + return translate(context); } @@ -353,7 +404,7 @@ public class CASFilter implements WebFilter { if (log.isTraceEnabled()) { log.trace("entering doFilter()"); } - // 执行中断策略 + // 执行跳过策略 String pt = context.getQuery("pt"); if (StringUtils.hasText(pt)) { if (session.getAttribute(pt) != null) { diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASLoginFilter.java b/src/main/java/dev/flyfish/boot/cas/filter/CASLoginFilter.java new file mode 100644 index 0000000..7ac7d21 --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASLoginFilter.java @@ -0,0 +1,20 @@ +package dev.flyfish.boot.cas.filter; + +/** + * 登录过滤器,旨在缓存用户名 + * + * @author wangyu + */ +public class CASLoginFilter implements CASContextInit { + public static String CONST_CAS_USERNAME = "const_cas_username"; + + @Override + public String getTranslatorUser(String username) { + return username; + } + + @Override + public void initContext(CASContext casContext, String username) { + casContext.setSessionAttribute(CONST_CAS_USERNAME, username); + } +} diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASParameter.java b/src/main/java/dev/flyfish/boot/cas/filter/CASParameter.java index 7fdef3c..a9827d6 100644 --- a/src/main/java/dev/flyfish/boot/cas/filter/CASParameter.java +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASParameter.java @@ -31,7 +31,7 @@ public class CASParameter { String casProxyCallbackUrl; @JsonAlias(CASFilter.CAS_FILTER_INITCONTEXTCLASS) - String casInitContextClass; + Class casInitContextClass; @JsonAlias(CASFilter.RENEW_INIT_PARAM) boolean casRenew; @@ -63,7 +63,7 @@ public class CASParameter { /** * 检查配置参数是否有误 */ - public void check() { + public CASParameter checked() { if (this.casGateway && this.casRenew) { throw new IllegalArgumentException("gateway and renew cannot both be true in filter configuration"); } else if (this.casServerName != null && this.casServiceUrl != null) { @@ -73,5 +73,6 @@ public class CASParameter { } else if (this.casValidate == null) { throw new IllegalArgumentException("validateUrl parameter must be set."); } + return this; } } diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties deleted file mode 100644 index 401009c..0000000 --- a/src/main/resources/application.properties +++ /dev/null @@ -1 +0,0 @@ -spring.application.name=cas diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 0000000..9530c01 --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,13 @@ +server: + port: 8080 + +spring: + application: + name: cas + +cas: + filter: + cas-login: https://sdsfzt.sxu.edu.cn/authserver/login + cas-validate: https://sdsfzt.sxu.edu.cn/authserver/serviceValidate + cas-server-name: 127.0.0.1:8080 + cas-init-context-class: dev.flyfish.boot.cas.filter.CASLoginFilter