feat: 初步完成实现

This commit is contained in:
wangyu 2024-10-12 15:35:08 +08:00
parent a1cd8a2ba2
commit b978c776b5
11 changed files with 645 additions and 210 deletions

View File

@ -0,0 +1,91 @@
package dev.flyfish.boot.cas.config;
import dev.flyfish.boot.cas.config.session.WebSessionDecorator;
import dev.flyfish.boot.cas.config.session.WebSessionListener;
import dev.flyfish.boot.cas.filter.CASFilter;
import dev.flyfish.boot.cas.filter.CASParameter;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.DefaultWebSessionManager;
import org.springframework.web.server.session.InMemoryWebSessionStore;
import org.springframework.web.server.session.WebSessionManager;
import org.springframework.web.server.session.WebSessionStore;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.util.List;
/**
* cas核心配置
*
* @author wangyu
*/
@Configuration
public class CASConfig {
@Bean
@ConfigurationProperties("cas.filter")
public CASParameter casParameter() {
return new CASParameter();
}
@Bean
public CASFilter casFilter(CASParameter casParameter) {
return new CASFilter(casParameter);
}
@Bean
public WebSessionStore webSessionStore(WebSessionManager webSessionManager, ServerProperties serverProperties,
ObjectProvider<WebSessionListener> listeners) {
if (webSessionManager instanceof DefaultWebSessionManager defaultWebSessionManager) {
Duration timeout = serverProperties.getReactive().getSession().getTimeout();
int maxSessions = serverProperties.getReactive().getSession().getMaxSessions();
ListenableWebSessionStore sessionStore = new ListenableWebSessionStore(timeout, listeners);
sessionStore.setMaxSessions(maxSessions);
defaultWebSessionManager.setSessionStore(sessionStore);
return sessionStore;
}
throw new IllegalStateException("Cannot find web session manager");
}
/**
* 处理session销毁保证正确退出
*
* @param casFilter 过滤器
* @return 结果
*/
@Bean
public WebSessionListener singleSignOutSessionListener(CASFilter casFilter) {
return new WebSessionListener() {
@Override
public Mono<Void> onSessionInvalidated(WebSession session) {
return casFilter.getSessionMappingStorage().removeBySessionById(session.getId());
}
};
}
static final class ListenableWebSessionStore extends InMemoryWebSessionStore {
private final Duration timeout;
private final List<WebSessionListener> listeners;
private ListenableWebSessionStore(Duration timeout, ObjectProvider<WebSessionListener> listeners) {
this.timeout = timeout;
this.listeners = listeners.stream().toList();
}
public Mono<WebSession> createWebSession() {
return super.createWebSession()
.map(session -> (WebSession) new WebSessionDecorator(session, listeners))
.doOnSuccess(this::setMaxIdleTime);
}
private void setMaxIdleTime(WebSession session) {
session.setMaxIdleTime(this.timeout);
}
}
}

View File

@ -0,0 +1,194 @@
package dev.flyfish.boot.cas.config.session;
import lombok.RequiredArgsConstructor;
import org.springframework.lang.Nullable;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
public class WebSessionDecorator implements WebSession {
private final WebSession decorated;
private final List<WebSessionListener> listeners;
/**
* Return a unique session identifier.
*/
@Override
public String getId() {
return decorated.getId();
}
/**
* Return a map that holds session attributes.
*/
@Override
public Map<String, Object> getAttributes() {
return decorated.getAttributes();
}
/**
* Return the session attribute value if present.
*
* @param name the attribute name
* @return the attribute value
*/
@Nullable
@Override
public <T> T getAttribute(String name) {
return decorated.getAttribute(name);
}
/**
* Return the session attribute value or if not present raise an
* {@link IllegalArgumentException}.
*
* @param name the attribute name
* @return the attribute value
*/
@Override
public <T> T getRequiredAttribute(String name) {
return decorated.getRequiredAttribute(name);
}
/**
* Return the session attribute value, or a default, fallback value.
*
* @param name the attribute name
* @param defaultValue a default value to return instead
* @return the attribute value
*/
@Override
public <T> T getAttributeOrDefault(String name, T defaultValue) {
return decorated.getAttributeOrDefault(name, defaultValue);
}
/**
* Force the creation of a session causing the session id to be sent when
* {@link #save()} is called.
*/
@Override
public void start() {
decorated.start();
}
/**
* Whether a session with the client has been started explicitly via
* {@link #start()} or implicitly by adding session attributes.
* If "false" then the session id is not sent to the client and the
* {@link #save()} method is essentially a no-op.
*/
@Override
public boolean isStarted() {
return decorated.isStarted();
}
/**
* Generate a new id for the session and update the underlying session
* storage to reflect the new id. After a successful call {@link #getId()}
* reflects the new session id.
*
* @return completion notification (success or error)
*/
@Override
public Mono<Void> changeSessionId() {
return decorated.changeSessionId();
}
/**
* Invalidate the current session and clear session storage.
*
* @return completion notification (success or error)
*/
@Override
public Mono<Void> invalidate() {
// 后续处理
Mono<Void> consumer = Mono.defer(() -> listeners.stream()
.map(listener -> listener.onSessionInvalidated(this.decorated))
.reduce(Mono::then)
.orElse(Mono.empty()));
return decorated.invalidate().then(consumer);
}
/**
* Save the session through the {@code WebSessionStore} as follows:
* <ul>
* <li>If the session is new (i.e. created but never persisted), it must have
* been started explicitly via {@link #start()} or implicitly by adding
* attributes, or otherwise this method should have no effect.
* <li>If the session was retrieved through the {@code WebSessionStore},
* the implementation for this method must check whether the session was
* {@link #invalidate() invalidated} and if so return an error.
* </ul>
* <p>Note that this method is not intended for direct use by applications.
* Instead it is automatically invoked just before the response is
* committed.
*
* @return {@code Mono} to indicate completion with success or error
*/
@Override
public Mono<Void> save() {
return decorated.save();
}
/**
* Return {@code true} if the session expired after {@link #getMaxIdleTime()
* maxIdleTime} elapsed.
* <p>Typically expiration checks should be automatically made when a session
* is accessed, a new {@code WebSession} instance created if necessary, at
* the start of request processing so that applications don't have to worry
* about expired session by default.
*/
@Override
public boolean isExpired() {
return decorated.isExpired();
}
/**
* Return the time when the session was created.
*/
@Override
public Instant getCreationTime() {
return decorated.getCreationTime();
}
/**
* Return the last time of session access as a result of user activity such
* as an HTTP request. Together with {@link #getMaxIdleTime()
* maxIdleTimeInSeconds} this helps to determine when a session is
* {@link #isExpired() expired}.
*/
@Override
public Instant getLastAccessTime() {
return decorated.getLastAccessTime();
}
/**
* Configure the max amount of time that may elapse after the
* {@link #getLastAccessTime() lastAccessTime} before a session is considered
* expired. A negative value indicates the session should not expire.
*
* @param maxIdleTime
*/
@Override
public void setMaxIdleTime(Duration maxIdleTime) {
decorated.setMaxIdleTime(maxIdleTime);
}
/**
* Return the maximum time after the {@link #getLastAccessTime()
* lastAccessTime} before a session expires. A negative time indicates the
* session doesn't expire.
*/
@Override
public Duration getMaxIdleTime() {
return decorated.getMaxIdleTime();
}
}

View File

@ -0,0 +1,21 @@
package dev.flyfish.boot.cas.config.session;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
/**
* web session监听器
*
* @author wangyu
* 基于装饰器增强实现可灵活处理
*/
public interface WebSessionListener {
default Mono<Void> onSessionCreated(WebSession session) {
return Mono.empty();
}
default Mono<Void> onSessionInvalidated(WebSession session) {
return Mono.empty();
}
}

View File

@ -0,0 +1,16 @@
package dev.flyfish.boot.cas.controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
@RestController
public class IndexController {
@GetMapping("/hello")
public Object index() {
return new HashMap<>();
}
}

View File

@ -13,17 +13,16 @@ import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.WebSessionManager;
import org.springframework.web.server.session.WebSessionStore;
import reactor.core.publisher.Mono;
import java.net.URI;
import java.util.Map;
/**
* cas 上下文
*/
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
class CASContext {
public class CASContext {
@Getter
private String ticket;
@ -38,6 +37,10 @@ class CASContext {
private final Map<String, Mono<String>> parameters = new ConcurrentReferenceHashMap<>();
@Setter
@Getter
private String username;
static Mono<CASContext> create(ServerWebExchange exchange, WebFilterChain chain) {
return new CASContext(exchange, chain).init();
}
@ -50,21 +53,20 @@ class CASContext {
Mono<WebSession> sessionMono = exchange.getSession()
.doOnNext(session -> this.session = session);
return Mono.zipDelayError(ticketMono, sessionMono)
.onErrorContinue((e, v) -> e.printStackTrace())
.thenReturn(this);
}
boolean isTokenRequest() {
return StringUtils.hasText(ticket);
}
Mono<Void> filter() {
return chain.filter(exchange);
}
/**
* 获取参数
*
* @param key
* @return 异步结果
*/
Mono<String> getParameter(String key) {
return parameters.computeIfAbsent(key, this::computeParameter);
Mono<Void> redirect(String url) {
return chain.filter(exchange.mutate().request(builder -> builder.uri(URI.create(url))).build());
}
ServerHttpRequest getRequest() {
@ -93,6 +95,20 @@ class CASContext {
.mapNotNull(formData -> formData.getFirst(key));
}
void setSessionAttribute(String key, Object value) {
session.getAttributes().put(key, value);
}
/**
* 获取参数
*
* @param key
* @return 异步结果
*/
private Mono<String> getParameter(String key) {
return parameters.computeIfAbsent(key, this::computeParameter);
}
private Mono<String> computeParameter(String key) {
return this.readParameter(key).cache();
}

View File

@ -0,0 +1,13 @@
package dev.flyfish.boot.cas.filter;
/**
* 上下文初始化逻辑
*
* @author wangyu
*/
public interface CASContextInit {
String getTranslatorUser(String username);
void initContext(CASContext casContext, String username);
}

View File

@ -1,13 +1,15 @@
package dev.flyfish.boot.cas.filter;
import edu.yale.its.tp.cas.client.*;
import edu.yale.its.tp.cas.client.CASAuthenticationException;
import edu.yale.its.tp.cas.client.CASReceipt;
import edu.yale.its.tp.cas.client.ProxyTicketValidator;
import edu.yale.its.tp.cas.util.XmlUtils;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpMethod;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
@ -15,12 +17,13 @@ import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
* cas filter的webflux实现
@ -49,15 +52,32 @@ public class CASFilter implements WebFilter {
static final String CAS_FILTER_EXCLUSION = "edu.yale.its.tp.cas.client.filter.filterExclusion";
private final CASParameter parameter;
private final CASContextInit initializer;
@Getter
private final SessionMappingStorage sessionMappingStorage = new SessionMappingStorage.HashMapBackedSessionStorage();
private static SessionMappingStorage SESSION_MAPPING_STORAGE = new SessionMappingStorage.HashMapBackedSessionStorage();
public CASFilter(CASParameter parameter) {
this.parameter = parameter.checked();
this.initializer = createInitializer();
}
private List<HttpMessageReader<?>> messageReaders;
public CASFilter(CASParameter parameter, ServerCodecConfigurer codecConfigurer) {
parameter.check();
this.parameter = parameter;
this.messageReaders = codecConfigurer.getReaders();
private CASContextInit createInitializer() {
if (null != parameter.casInitContextClass) {
try {
// 未正确配置类型抛弃
Class<? extends CASContextInit> cls = parameter.casInitContextClass.asSubclass(CASContextInit.class);
// 实例化对象并返回
return cls.getConstructor().newInstance();
} catch (ClassCastException e) {
log.warn("cas context init class not implements CASContextInit", e);
} catch (IllegalArgumentException | IllegalAccessException | InstantiationException
| SecurityException | NoSuchMethodException e) {
log.error("error when initialize the context init class", e);
} catch (InvocationTargetException e) {
log.error("error when create the cas context initializer's instance!", e);
}
}
return null;
}
private boolean isReceiptAcceptable(CASReceipt receipt) {
@ -70,12 +90,12 @@ public class CASFilter implements WebFilter {
}
}
private CASReceipt getAuthenticatedUser(ServerHttpRequest request, String ticket) throws CASAuthenticationException {
private CASReceipt getAuthenticatedUser(CASContext context) throws CASAuthenticationException {
log.trace("entering getAuthenticatedUser()");
ProxyTicketValidator pv = new ProxyTicketValidator();
pv.setCasValidateUrl(parameter.casValidate);
pv.setServiceTicket(ticket);
pv.setService(this.getService(request));
pv.setServiceTicket(context.getTicket());
pv.setService(this.getService(context));
pv.setRenew(parameter.casRenew);
if (parameter.casProxyCallbackUrl != null) {
pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl);
@ -86,16 +106,18 @@ public class CASFilter implements WebFilter {
return CASReceipt.getReceipt(pv);
}
private String getService(ServerHttpRequest request) {
private String getService(CASContext context) {
log.trace("entering getService()");
if (parameter.casServerName == null && parameter.casServiceUrl == null) {
throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName");
} else {
}
String serviceString;
if (parameter.casServiceUrl != null) {
serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8);
} else {
serviceString = Util.getService(request, parameter.casServerName);
serviceString = computeService(context, parameter.casServerName);
}
if (log.isTraceEnabled()) {
@ -103,6 +125,52 @@ public class CASFilter implements WebFilter {
}
return serviceString;
}
/**
* 计算服务地址主要是替换url中server的部分并去除ticket
*
* @param context 上下文
* @param server 服务
* @return 结果
*/
public static String computeService(CASContext context, String server) {
if (log.isTraceEnabled()) {
log.trace("entering getService({}, {})", context, server);
}
if (server == null) {
log.error("getService() argument \"server\" was illegally null.");
throw new IllegalArgumentException("name of server is required");
}
URI uri = context.getRequest().getURI();
StringBuilder sb = new StringBuilder();
sb.append(uri.getScheme()).append("://").append(server).append(uri.getPath());
if (uri.getQuery() != null) {
String query = uri.getQuery();
int ticketLoc = query.indexOf("ticket=");
if (ticketLoc == -1) {
sb.append("?").append(query);
} else if (ticketLoc > 0) {
ticketLoc = query.indexOf("&ticket=");
if (ticketLoc == -1) {
sb.append("?").append(query);
} else if (ticketLoc > 0) {
sb.append("?").append(query, 0, ticketLoc);
}
}
}
String encodedService = URLEncoder.encode(sb.toString(), StandardCharsets.UTF_8);
if (log.isTraceEnabled()) {
log.trace("returning from getService() with encoded service [{}]", encodedService);
}
return encodedService;
}
/**
@ -110,71 +178,58 @@ public class CASFilter implements WebFilter {
*
* @param context 上下文
* @return 结果
* @throws IOException 异常
*/
private Mono<Void> redirectToCAS(CASContext context) {
ServerHttpRequest request = context.getRequest();
ServerHttpResponse response = context.getResponse();
String sessionId = context.getSession().getId();
log.trace("entering redirectToCAS()");
String casLoginString = parameter.casLogin + "?service=" + this.getService(request) +
(parameter.casRenew ? "&renew=true" : "") + (parameter.casGateway ? "&gateway=true" : "");
String sCookie;
sCookie = parameter.casServerName + request.getPath().contextPath().value();
casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + sessionId;
StringBuilder casLoginString = new StringBuilder()
.append(parameter.casLogin)
.append("?service=").append(this.getService(context))
.append(parameter.casRenew ? "&renew=true" : "")
.append(parameter.casGateway ? "&gateway=true" : "");
sCookie = request.getHeaders().getFirst("Cookie");
String cookie = null;
if (sCookie != null) {
String[] sCookies = sCookie.split(";");
for (int i = 0; i < sCookies.length; ++i) {
if (sCookies[i].indexOf("JSESSIONID=") != -1) {
cookie = sCookies[i].split("JSESSIONID=")[1];
}
}
if (StringUtils.hasText(sessionId)) {
String appId = parameter.casServerName + request.getPath().contextPath().value();
casLoginString.append("&appId=").append(appId)
.append("&sessionId=").append(sessionId);
}
if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) {
casLoginString = casLoginString + "&timeOut=" + cookie;
if (log.isDebugEnabled()) {
log.debug("Session is timeout. The timeout session is {}", cookie);
}
List<HttpCookie> cookies = request.getCookies().get("JSESSIONID");
if (!CollectionUtils.isEmpty(cookies)) {
cookies.stream()
.filter(Objects::nonNull)
.map(HttpCookie::getValue)
.filter(cookie -> !cookie.equals("null") && !cookie.equals(sessionId))
.peek(cookie -> log.debug("Session is timeout. The timeout session is {}", cookie))
.findFirst()
.ifPresent(cookie -> casLoginString.append("&timeOut=").append(cookie));
}
if (log.isDebugEnabled()) {
log.debug("Redirecting browser to [" + casLoginString + ")");
}
response.sendRedirect(casLoginString);
if (log.isTraceEnabled()) {
log.debug("Redirecting browser to [{})", casLoginString);
log.trace("returning from redirectToCAS()");
return context.redirect(casLoginString.toString());
}
}
private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException {
private Mono<Void> redirectToInitFailure(CASContext context, String cause) {
log.trace("entering redirectToInitFailure()");
String casLoginString = this.casLogin + "?action=initFailure";
String casLoginString = parameter.casLogin + "?action=initFailure";
if (cause != null && cause.equals("Illegal user")) {
casLoginString = casLoginString + "&userIllegal=true";
casLoginString += "&userIllegal=true";
}
String locale = request.getParameter("locale");
String locale = context.getQuery("locale");
if (locale != null) {
casLoginString = casLoginString + "&locale=" + locale;
casLoginString += "&locale=" + locale;
}
log.debug("Redirecting browser to [" + casLoginString + ")");
response.sendRedirect(casLoginString);
log.debug("Redirecting browser to [{})", casLoginString);
log.trace("returning from redirectToInitFailure()");
}
public static SessionMappingStorage getSessionMappingStorage() {
return SESSION_MAPPING_STORAGE;
return context.redirect(casLoginString);
}
private boolean isExclusion(String url) {
@ -186,11 +241,14 @@ public class CASFilter implements WebFilter {
}
private Mono<Void> translate(CASContext context) {
// 是代理回调地址通过
if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath())
&& context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) {
log.trace("passing through what we hope is CAS's request for proxy ticket receptor.");
return context.filter();
} else {
}
// 请求包装增强请求并完成自定义功能
if (parameter.wrapRequest) {
log.trace("Wrapping request with CASFilterRequestWrapper.");
// todo 暂时啥也不干看看有无问题
@ -198,105 +256,99 @@ public class CASFilter implements WebFilter {
}
WebSession session = context.getSession();
Map<String, Object> sessionAttributes = session.getAttributes();
// 使用了用户标记快速跳过
if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) {
return context.filter();
}
// 获取receipt
// 获取receipt若存在则通过
CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT);
if (receipt != null && this.isReceiptAcceptable(receipt)) {
log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter..");
return context.filter();
}
// 跳过请求
// 命中排除地址跳过请求
if (this.isExclusion(context.getPath())) {
return context.filter();
}
// 判断票据
String ticket = context.getTicket();
// 存在票据时验证票据
if (StringUtils.hasText(ticket)) {
try {
receipt = this.getAuthenticatedUser(context.getRequest(), ticket);
} catch (CASAuthenticationException var22) {
receipt = this.getAuthenticatedUser(context);
} catch (CASAuthenticationException e) {
return this.redirectToCAS(context);
}
if (!this.isReceiptAcceptable(receipt)) {
throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]");
} else {
}
// 记录receipt
String pt = context.getQuery("pt");
if (StringUtils.hasText(pt)) {
session.getAttributes().put(pt, receipt);
context.setSessionAttribute(pt, receipt);
}
// 获取到用户名
String userName = receipt.getUserName();
if (StringUtils.hasText(parameter.casInitContextClass)) {
// 尝试初始化
if (null != initializer) {
try {
Class<?> cls = Class.forName(parameter.casInitContextClass);
Object obj = cls.getConstructor().newInstance();
if (obj instanceof IContextInit) {
Method translatorMethod = cls.getMethod("getTranslatorUser", String.class);
userName = (String) translatorMethod.invoke(obj, userName);
Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class);
initContextMethod.invoke(obj, request, response, fc, userName);
}
} catch (ClassNotFoundException | IllegalArgumentException | IllegalAccessException |
InstantiationException | SecurityException | NoSuchMethodException e) {
e.printStackTrace();
} catch (InvocationTargetException var19) {
InvocationTargetException e = var19;
String translated = initializer.getTranslatorUser(userName);
log.debug("translated username: {} to {}", userName, translated);
initializer.initContext(context, translated);
} catch (Exception e) {
String cause = e.getCause().getMessage();
session.setAttribute("initFailure", cause);
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
e.printStackTrace();
return Mono.empty();
context.setSessionAttribute("initFailure", cause);
return this.redirectToInitFailure(context, cause);
}
}
session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName);
session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt);
session.removeAttribute(CAS_FILTER_GATEWAYED);
sessionAttributes.put(CAS_FILTER_USER, userName);
sessionAttributes.put(CAS_FILTER_RECEIPT, receipt);
sessionAttributes.remove(CAS_FILTER_GATEWAYED);
if (log.isTraceEnabled()) {
log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt);
log.trace("returning from doFilter()");
}
return context.filter();
}
return Mono.empty();
} else {
// 不存在票据跳转验证
log.trace("CAS ticket was not present on request.");
boolean didGateway = Boolean.valueOf((String) session.getAttribute(CAS_FILTER_GATEWAYED));
boolean didGateway = Boolean.parseBoolean(session.getAttribute(CAS_FILTER_GATEWAYED));
if (parameter.casLogin == null) {
log.error("casLogin was not set, so filter cannot redirect request for authentication.");
throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter");
} else if (!didGateway) {
}
if (!didGateway) {
log.trace("Did not previously gateway. Setting session attribute to true.");
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute(CAS_FILTER_GATEWAYED, "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
} else {
sessionAttributes.put(CAS_FILTER_GATEWAYED, "true");
return this.redirectToCAS(context);
}
log.trace("Previously gatewayed.");
if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) {
if (!parameter.casGateway && session.getAttribute(CAS_FILTER_USER) == null) {
if (session.getAttribute("initFailure") != null) {
String cause = (String) session.getAttribute("initFailure");
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
} else {
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute(CAS_FILTER_GATEWAYED, "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
String cause = session.getAttribute("initFailure");
return this.redirectToInitFailure(context, cause);
}
sessionAttributes.put(CAS_FILTER_GATEWAYED, "true");
return this.redirectToCAS(context);
}
} else {
log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain.");
fc.doFilter((ServletRequest) request, response);
}
}
}
}
return context.filter();
}
/**
@ -306,19 +358,30 @@ public class CASFilter implements WebFilter {
* @return 结果
*/
private Mono<Void> handle(CASContext context) {
// 优先处理token请求
if (context.isTokenRequest()) {
String sessionId = context.getSession().getId();
log.debug("Storing session identifier for {}", sessionId);
// 包括ticket尝试重新替换session
return sessionMappingStorage.removeBySessionById(sessionId)
.onErrorContinue((e, v) -> log.debug("error when remove session"))
.then(Mono.defer(() -> sessionMappingStorage.addSessionById(context.getTicket(), context.getSession())
.then(translate(context))));
}
// post请求需要特殊处理
if (context.getMethod() == HttpMethod.POST) {
// 此处可能要求安全的获取参数单独针对退出请求
// 通过form表单获取注销请求处理注销逻辑
return context.getFormData("logoutRequest")
.doOnNext(payload -> log.trace("Logout request=[{}]", payload))
.defaultIfEmpty("")
.flatMap(payload -> {
if (StringUtils.hasText(payload)) {
String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex");
if (StringUtils.hasText(sessionIdentifier)) {
String token = XmlUtils.getTextForElement(payload, "SessionIndex");
if (StringUtils.hasText(token)) {
// 满足条件时断路
return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier)
.doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), sessionIdentifier))
return sessionMappingStorage.removeSessionByMappingId(token)
.doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), token))
.flatMap(WebSession::invalidate)
.doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e))
.onErrorComplete();
@ -327,21 +390,9 @@ public class CASFilter implements WebFilter {
// 继续执行
return translate(context);
});
} else {
String ticket = context.getTicket();
String sessionId = context.getSession().getId();
log.debug("Storing session identifier for {}", sessionId);
// 包括ticket尝试重新替换session
if (StringUtils.hasText(ticket)) {
return SESSION_MAPPING_STORAGE.removeBySessionById(sessionId)
.onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, context.getSession()))
.then(Mono.defer(() -> translate(context)));
}
return translate(context);
}
}
@Override
@ -353,7 +404,7 @@ public class CASFilter implements WebFilter {
if (log.isTraceEnabled()) {
log.trace("entering doFilter()");
}
// 执行中断策略
// 执行跳过策略
String pt = context.getQuery("pt");
if (StringUtils.hasText(pt)) {
if (session.getAttribute(pt) != null) {

View File

@ -0,0 +1,20 @@
package dev.flyfish.boot.cas.filter;
/**
* 登录过滤器旨在缓存用户名
*
* @author wangyu
*/
public class CASLoginFilter implements CASContextInit {
public static String CONST_CAS_USERNAME = "const_cas_username";
@Override
public String getTranslatorUser(String username) {
return username;
}
@Override
public void initContext(CASContext casContext, String username) {
casContext.setSessionAttribute(CONST_CAS_USERNAME, username);
}
}

View File

@ -31,7 +31,7 @@ public class CASParameter {
String casProxyCallbackUrl;
@JsonAlias(CASFilter.CAS_FILTER_INITCONTEXTCLASS)
String casInitContextClass;
Class<?> casInitContextClass;
@JsonAlias(CASFilter.RENEW_INIT_PARAM)
boolean casRenew;
@ -63,7 +63,7 @@ public class CASParameter {
/**
* 检查配置参数是否有误
*/
public void check() {
public CASParameter checked() {
if (this.casGateway && this.casRenew) {
throw new IllegalArgumentException("gateway and renew cannot both be true in filter configuration");
} else if (this.casServerName != null && this.casServiceUrl != null) {
@ -73,5 +73,6 @@ public class CASParameter {
} else if (this.casValidate == null) {
throw new IllegalArgumentException("validateUrl parameter must be set.");
}
return this;
}
}

View File

@ -1 +0,0 @@
spring.application.name=cas

View File

@ -0,0 +1,13 @@
server:
port: 8080
spring:
application:
name: cas
cas:
filter:
cas-login: https://sdsfzt.sxu.edu.cn/authserver/login
cas-validate: https://sdsfzt.sxu.edu.cn/authserver/serviceValidate
cas-server-name: 127.0.0.1:8080
cas-init-context-class: dev.flyfish.boot.cas.filter.CASLoginFilter