From a1cd8a2ba2ee94072652f1bbdd66736cc9c6496a Mon Sep 17 00:00:00 2001 From: wangyu Date: Sat, 12 Oct 2024 00:32:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9A=82=E5=AD=98=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=EF=BC=8C=E8=BF=98=E6=9C=89=E5=BE=88=E5=A4=9A=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../flyfish/boot/cas/filter/CASContext.java | 111 ++++++ .../flyfish/boot/cas/filter/CASFilter.java | 364 ++++++++---------- .../cas/filter/SessionMappingStorage.java | 4 +- 3 files changed, 266 insertions(+), 213 deletions(-) create mode 100644 src/main/java/dev/flyfish/boot/cas/filter/CASContext.java diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java b/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java new file mode 100644 index 0000000..587a93d --- /dev/null +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASContext.java @@ -0,0 +1,111 @@ +package dev.flyfish.boot.cas.filter; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.ConcurrentReferenceHashMap; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.session.WebSessionManager; +import org.springframework.web.server.session.WebSessionStore; +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * cas 上下文 + */ +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +class CASContext { + + @Getter + private String ticket; + + @Setter + @Getter + private WebSession session; + + private final ServerWebExchange exchange; + + private final WebFilterChain chain; + + private final Map> parameters = new ConcurrentReferenceHashMap<>(); + + static Mono create(ServerWebExchange exchange, WebFilterChain chain) { + return new CASContext(exchange, chain).init(); + } + + private Mono init() { + Mono ticketMono = getParameter("ticket") + .filter(StringUtils::hasText) + .doOnNext(ticket -> this.ticket = ticket); + // 此处必须保证session不为空 + Mono sessionMono = exchange.getSession() + .doOnNext(session -> this.session = session); + return Mono.zipDelayError(ticketMono, sessionMono) + .thenReturn(this); + } + + Mono filter() { + return chain.filter(exchange); + } + + /** + * 获取参数 + * + * @param key 键 + * @return 异步结果 + */ + Mono getParameter(String key) { + return parameters.computeIfAbsent(key, this::computeParameter); + } + + ServerHttpRequest getRequest() { + return exchange.getRequest(); + } + + ServerHttpResponse getResponse() { + return exchange.getResponse(); + } + + String getPath() { + return exchange.getRequest().getPath().value(); + } + + HttpMethod getMethod() { + return exchange.getRequest().getMethod(); + } + + String getQuery(String key) { + ServerHttpRequest request = exchange.getRequest(); + return request.getQueryParams().getFirst(key); + } + + Mono getFormData(String key) { + return exchange.getFormData() + .mapNotNull(formData -> formData.getFirst(key)); + } + + private Mono computeParameter(String key) { + return this.readParameter(key).cache(); + } + + private Mono readParameter(String key) { + String query = getQuery(key); + if (StringUtils.hasText(query)) { + return Mono.just(query); + } + MediaType mediaType = exchange.getRequest().getHeaders().getContentType(); + if (null != mediaType && mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { + return exchange.getFormData().mapNotNull(formData -> formData.getFirst(key)); + } + return Mono.empty(); + } +} diff --git a/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java b/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java index 4583589..4c8e3b8 100644 --- a/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java +++ b/src/main/java/dev/flyfish/boot/cas/filter/CASFilter.java @@ -2,15 +2,12 @@ package dev.flyfish.boot.cas.filter; import edu.yale.its.tp.cas.client.*; import edu.yale.its.tp.cas.util.XmlUtils; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.lang.NonNull; -import org.springframework.util.MultiValueMap; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; @@ -22,8 +19,8 @@ import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; import java.util.List; -import java.util.Objects; /** * cas filter的webflux实现 @@ -31,10 +28,9 @@ import java.util.Objects; * @author wangyu * 实现相关核心逻辑,完成鉴权信息抽取 */ +@Slf4j public class CASFilter implements WebFilter { - private static final Log log = LogFactory.getLog(CASFilter.class); - public static final String LOGIN_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.loginUrl"; public static final String VALIDATE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.validateUrl"; public static final String SERVICE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.serviceUrl"; @@ -54,7 +50,7 @@ public class CASFilter implements WebFilter { private final CASParameter parameter; - private static SessionMappingStorage SESSION_MAPPING_STORAGE = new HashMapBackedSessionMappingStorage(); + private static SessionMappingStorage SESSION_MAPPING_STORAGE = new SessionMappingStorage.HashMapBackedSessionStorage(); private List> messageReaders; @@ -74,57 +70,61 @@ public class CASFilter implements WebFilter { } } - private CASReceipt getAuthenticatedUser(ServerWebExchange exchange, String ticket) throws CASAuthenticationException { + private CASReceipt getAuthenticatedUser(ServerHttpRequest request, String ticket) throws CASAuthenticationException { log.trace("entering getAuthenticatedUser()"); ProxyTicketValidator pv = new ProxyTicketValidator(); pv.setCasValidateUrl(parameter.casValidate); pv.setServiceTicket(ticket); - pv.setService(this.getService(exchange.getRequest())); + pv.setService(this.getService(request)); pv.setRenew(parameter.casRenew); if (parameter.casProxyCallbackUrl != null) { pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl); } - if (log.isDebugEnabled()) { - log.debug("about to validate ProxyTicketValidator: [" + pv + "]"); - } + log.debug("about to validate ProxyTicketValidator: [{}]", pv); return CASReceipt.getReceipt(pv); } - private String getService(ServerHttpRequest request) throws ServletException { + private String getService(ServerHttpRequest request) { log.trace("entering getService()"); - if (this.casServerName == null && this.casServiceUrl == null) { - throw new ServletException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName"); + if (parameter.casServerName == null && parameter.casServiceUrl == null) { + throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName"); } else { String serviceString; - if (this.casServiceUrl != null) { - serviceString = URLEncoder.encode(this.casServiceUrl); + if (parameter.casServiceUrl != null) { + serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8); } else { - serviceString = Util.getService(request, this.casServerName); + serviceString = Util.getService(request, parameter.casServerName); } if (log.isTraceEnabled()) { - log.trace("returning from getService() with service [" + serviceString + "]"); + log.trace("returning from getService() with service [{}]", serviceString); } - return serviceString; } } - private void redirectToCAS(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { - if (log.isTraceEnabled()) { - log.trace("entering redirectToCAS()"); - } + /** + * 核心,跳转cas服务器鉴权 + * + * @param context 上下文 + * @return 结果 + * @throws IOException 异常 + */ + private Mono redirectToCAS(CASContext context) { + ServerHttpRequest request = context.getRequest(); + ServerHttpResponse response = context.getResponse(); + String sessionId = context.getSession().getId(); + log.trace("entering redirectToCAS()"); - String casLoginString = this.casLogin + "?service=" + this.getService(request) + (this.casRenew ? "&renew=true" : "") + (this.casGateway ? "&gateway=true" : ""); + String casLoginString = parameter.casLogin + "?service=" + this.getService(request) + + (parameter.casRenew ? "&renew=true" : "") + (parameter.casGateway ? "&gateway=true" : ""); String sCookie; - if (request.getAttribute("sessionId") != null) { - sCookie = this.casServerName + request.getContextPath(); - casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + request.getAttribute("sessionId"); - } + sCookie = parameter.casServerName + request.getPath().contextPath().value(); + casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + sessionId; - sCookie = request.getHeader("Cookie"); + sCookie = request.getHeaders().getFirst("Cookie"); String cookie = null; if (sCookie != null) { String[] sCookies = sCookie.split(";"); @@ -139,7 +139,7 @@ public class CASFilter implements WebFilter { if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) { casLoginString = casLoginString + "&timeOut=" + cookie; if (log.isDebugEnabled()) { - log.debug("Session is timeout. The timeout session is " + cookie); + log.debug("Session is timeout. The timeout session is {}", cookie); } } @@ -155,9 +155,7 @@ public class CASFilter implements WebFilter { } private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException { - if (log.isTraceEnabled()) { - log.trace("entering redirectToInitFailure()"); - } + log.trace("entering redirectToInitFailure()"); String casLoginString = this.casLogin + "?action=initFailure"; if (cause != null && cause.equals("Illegal user")) { @@ -169,40 +167,29 @@ public class CASFilter implements WebFilter { casLoginString = casLoginString + "&locale=" + locale; } - if (log.isDebugEnabled()) { - log.debug("Redirecting browser to [" + casLoginString + ")"); - } + log.debug("Redirecting browser to [" + casLoginString + ")"); response.sendRedirect(casLoginString); - if (log.isTraceEnabled()) { - log.trace("returning from redirectToInitFailure()"); - } - + log.trace("returning from redirectToInitFailure()"); } public static SessionMappingStorage getSessionMappingStorage() { return SESSION_MAPPING_STORAGE; } - private boolean isExclusion(ServerHttpRequest request) { + private boolean isExclusion(String url) { if (parameter.exclusions == null) { return false; } else { - String url = request.getPath().value(); return parameter.exclusions.contains(url); } } - private Mono translate(ServerWebExchange exchange, WebFilterChain chain, WebSession session) { - ServerHttpRequest request = exchange.getRequest(); - MultiValueMap params = request.getQueryParams(); - String userName; - String artifact; - - if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(request.getPath().value()) - && params.getFirst("pgtId") != null && params.getFirst("pgtIou") != null) { + private Mono translate(CASContext context) { + if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath()) + && context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) { log.trace("passing through what we hope is CAS's request for proxy ticket receptor."); - return chain.filter(exchange); + return context.filter(); } else { if (parameter.wrapRequest) { log.trace("Wrapping request with CASFilterRequestWrapper."); @@ -210,122 +197,104 @@ public class CASFilter implements WebFilter { // request = new CASFilterRequestWrapper((HttpServletRequest) request); } + WebSession session = context.getSession(); // 使用了用户标记,快速跳过 if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) { - return chain.filter(exchange); + return context.filter(); } // 获取receipt CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT); if (receipt != null && this.isReceiptAcceptable(receipt)) { log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter.."); - return chain.filter(exchange); + return context.filter(); } // 跳过请求 - if (this.isExclusion(request)) { - return chain.filter(exchange); + if (this.isExclusion(context.getPath())) { + return context.filter(); } - return getParameter(exchange, "ticket") - .flatMap(ticket -> { - if (StringUtils.hasText(ticket)) { - try { - receipt = this.getAuthenticatedUser(exchange, ticket); - } catch (CASAuthenticationException var22) { - ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); - this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); - return; - } - - if (!this.isReceiptAcceptable(receipt)) { - throw new ServletException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); - } else { - if (pt != null && pt != "") { - session.setAttribute(pt, receipt); - } - - if (session != null) { - userName = receipt.getUserName(); - if (this.casInitContextClass != null && !"".equals(this.casInitContextClass)) { - try { - Class cls = Class.forName(this.casInitContextClass); - Object obj = cls.newInstance(); - if (obj instanceof IContextInit) { - Method translatorMethod = cls.getMethod("getTranslatorUser", String.class); - userName = (String) translatorMethod.invoke(obj, userName); - Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class); - initContextMethod.invoke(obj, request, response, fc, userName); - } - } catch (ClassNotFoundException var15) { - ClassNotFoundException e = var15; - e.printStackTrace(); - } catch (InstantiationException var16) { - InstantiationException e = var16; - e.printStackTrace(); - } catch (IllegalAccessException var17) { - IllegalAccessException e = var17; - e.printStackTrace(); - } catch (IllegalArgumentException var18) { - IllegalArgumentException e = var18; - e.printStackTrace(); - } catch (InvocationTargetException var19) { - InvocationTargetException e = var19; - String cause = e.getCause().getMessage(); - session.setAttribute("initFailure", cause); - this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); - e.printStackTrace(); - return; - } catch (SecurityException var20) { - SecurityException e = var20; - e.printStackTrace(); - } catch (NoSuchMethodException var21) { - NoSuchMethodException e = var21; - e.printStackTrace(); - } - } - - session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName); - session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt); - session.removeAttribute("edu.yale.its.tp.cas.client.filter.didGateway"); - } - - if (log.isTraceEnabled()) { - log.trace("validated ticket to get authenticated receipt [" + receipt + "], now passing request along filter chain."); - } - - fc.doFilter((ServletRequest) request, response); - log.trace("returning from doFilter()"); - } - } else { - log.trace("CAS ticket was not present on request."); - boolean didGateway = Boolean.valueOf((String) session.getAttribute("edu.yale.its.tp.cas.client.filter.didGateway")); - if (this.casLogin == null) { - log.fatal("casLogin was not set, so filter cannot redirect request for authentication."); - throw new ServletException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter"); - } else if (!didGateway) { - log.trace("Did not previously gateway. Setting session attribute to true."); - ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); - session.setAttribute("edu.yale.its.tp.cas.client.filter.didGateway", "true"); - this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); - } else { - log.trace("Previously gatewayed."); - if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) { - if (session.getAttribute("initFailure") != null) { - String cause = (String) session.getAttribute("initFailure"); - this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); - } else { - ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); - session.setAttribute("edu.yale.its.tp.cas.client.filter.didGateway", "true"); - this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); - } - - } else { - log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain."); - fc.doFilter((ServletRequest) request, response); - } + // 判断票据 + String ticket = context.getTicket(); + if (StringUtils.hasText(ticket)) { + try { + receipt = this.getAuthenticatedUser(context.getRequest(), ticket); + } catch (CASAuthenticationException var22) { + return this.redirectToCAS(context); + } + + if (!this.isReceiptAcceptable(receipt)) { + throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); + } else { + String pt = context.getQuery("pt"); + if (StringUtils.hasText(pt)) { + session.getAttributes().put(pt, receipt); + } + + String userName = receipt.getUserName(); + if (StringUtils.hasText(parameter.casInitContextClass)) { + try { + Class cls = Class.forName(parameter.casInitContextClass); + Object obj = cls.getConstructor().newInstance(); + if (obj instanceof IContextInit) { + Method translatorMethod = cls.getMethod("getTranslatorUser", String.class); + userName = (String) translatorMethod.invoke(obj, userName); + Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class); + initContextMethod.invoke(obj, request, response, fc, userName); } + } catch (ClassNotFoundException | IllegalArgumentException | IllegalAccessException | + InstantiationException | SecurityException | NoSuchMethodException e) { + e.printStackTrace(); + } catch (InvocationTargetException var19) { + InvocationTargetException e = var19; + String cause = e.getCause().getMessage(); + session.setAttribute("initFailure", cause); + this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); + e.printStackTrace(); + return Mono.empty(); } - }) + } + + session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName); + session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt); + session.removeAttribute(CAS_FILTER_GATEWAYED); + + log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt); + + log.trace("returning from doFilter()"); + + return context.filter(); + } + return Mono.empty(); + } else { + log.trace("CAS ticket was not present on request."); + boolean didGateway = Boolean.valueOf((String) session.getAttribute(CAS_FILTER_GATEWAYED)); + if (parameter.casLogin == null) { + log.error("casLogin was not set, so filter cannot redirect request for authentication."); + throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter"); + } else if (!didGateway) { + log.trace("Did not previously gateway. Setting session attribute to true."); + ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); + session.setAttribute(CAS_FILTER_GATEWAYED, "true"); + this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); + } else { + log.trace("Previously gatewayed."); + if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) { + if (session.getAttribute("initFailure") != null) { + String cause = (String) session.getAttribute("initFailure"); + this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); + } else { + ((HttpServletRequest) request).setAttribute("sessionId", session.getId()); + session.setAttribute(CAS_FILTER_GATEWAYED, "true"); + this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); + } + + } else { + log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain."); + fc.doFilter((ServletRequest) request, response); + } + } + } } } @@ -333,94 +302,67 @@ public class CASFilter implements WebFilter { /** * 二阶段处理,预处理特殊情况,提前中断请求 * - * @param exchange 交换信息 - * @param chain 过滤器链 - * @param session 会话 + * @param context 上下文工具 * @return 结果 */ - private Mono handle(ServerWebExchange exchange, WebFilterChain chain, @NonNull WebSession session) { - // 下一步处理信号,提前生成 - Mono translate = this.translate(exchange, chain, session); - + private Mono handle(CASContext context) { // post请求需要特殊处理 - if (exchange.getRequest().getMethod() == HttpMethod.POST) { + if (context.getMethod() == HttpMethod.POST) { // 此处可能要求安全的获取参数(单独针对退出请求) - return exchange.getFormData() - .flatMap(formData -> { - String payload = formData.getFirst("logoutRequest"); + return context.getFormData("logoutRequest") + .doOnNext(payload -> log.trace("Logout request=[{}]", payload)) + .defaultIfEmpty("") + .flatMap(payload -> { if (StringUtils.hasText(payload)) { - if (log.isTraceEnabled()) { - log.trace("Logout request=[" + payload + "]"); - } String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex"); if (StringUtils.hasText(sessionIdentifier)) { - // 命中该请求,中断执行 + // 满足条件时断路 return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier) - .filter(Objects::nonNull) - .flatMap(savedSession -> { - String sessionId = savedSession.getId(); - if (log.isDebugEnabled()) { - log.debug("Invalidating session [" + sessionId + "] for ST [" + sessionIdentifier + "]"); - } - try { - return savedSession.invalidate(); - } catch (IllegalStateException e) { - log.debug(e, e); - } - // 中断处理 - return Mono.empty(); - }); + .doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), sessionIdentifier)) + .flatMap(WebSession::invalidate) + .doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e)) + .onErrorComplete(); } } - return translate; + // 继续执行 + return translate(context); }); } else { - String ticket = exchange.getRequest().getQueryParams().getFirst("ticket"); - if (log.isDebugEnabled()) { - log.debug("Storing session identifier for " + session.getId()); - } + String ticket = context.getTicket(); + String sessionId = context.getSession().getId(); + log.debug("Storing session identifier for {}", sessionId); // 包括ticket,尝试重新替换session if (StringUtils.hasText(ticket)) { - return SESSION_MAPPING_STORAGE.removeBySessionById(session.getId()) - .onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, session)) - .then(translate); + return SESSION_MAPPING_STORAGE.removeBySessionById(sessionId) + .onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, context.getSession())) + .then(Mono.defer(() -> translate(context))); } - return translate; + return translate(context); } } - private Mono getParameter(ServerWebExchange exchange, String key) { - ServerHttpRequest request = exchange.getRequest(); - String query = request.getQueryParams().getFirst(key); - if (StringUtils.hasText(query)) { - return Mono.just(query); - } - MediaType mediaType = request.getHeaders().getContentType(); - if (null != mediaType && mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { - return exchange.getFormData().mapNotNull(formData -> formData.getFirst(key)); - } - return Mono.empty(); - } @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { // 拦截器需要基于session判定,故提前使用 - return exchange.getSession() - .flatMap(session -> { + return CASContext.create(exchange, chain) + .flatMap(context -> { + WebSession session = context.getSession(); if (log.isTraceEnabled()) { log.trace("entering doFilter()"); } // 执行中断策略 - String pt = exchange.getRequest().getQueryParams().getFirst("pt"); + String pt = context.getQuery("pt"); if (StringUtils.hasText(pt)) { if (session.getAttribute(pt) != null) { - return chain.filter(exchange); + return context.filter(); } } - return handle(exchange, chain, session); + return handle(context); }); } + } diff --git a/src/main/java/dev/flyfish/boot/cas/filter/SessionMappingStorage.java b/src/main/java/dev/flyfish/boot/cas/filter/SessionMappingStorage.java index 393cd89..f320ced 100644 --- a/src/main/java/dev/flyfish/boot/cas/filter/SessionMappingStorage.java +++ b/src/main/java/dev/flyfish/boot/cas/filter/SessionMappingStorage.java @@ -30,9 +30,9 @@ public interface SessionMappingStorage { public Mono removeSessionByMappingId(String mappingId) { WebSession session = this.MANAGED_SESSIONS.get(mappingId); if (session != null) { - this.removeBySessionById(session.getId()); + return this.removeBySessionById(session.getId()).thenReturn(session); } - return Mono.just(session); + return Mono.empty(); } @Override