From 092890ee1967462566dc81be1daca03aa06d5eec Mon Sep 17 00:00:00 2001 From: Handiwork Date: Fri, 15 Mar 2024 23:39:34 +0800 Subject: [PATCH] refactor: AuthenticationHelper --- .../config/security/AuthenticationHelper.kt | 58 +++++++++---------- .../backend/controller/CopilotController.kt | 8 +-- .../controller/CopilotSetController.kt | 2 +- .../backend/controller/file/FileController.kt | 2 +- .../maa/backend/repository/entity/MaaUser.kt | 8 --- .../plus/maa/backend/service/UserService.kt | 30 ++++------ 6 files changed, 44 insertions(+), 64 deletions(-) diff --git a/src/main/kotlin/plus/maa/backend/config/security/AuthenticationHelper.kt b/src/main/kotlin/plus/maa/backend/config/security/AuthenticationHelper.kt index 2499b96f..e811665f 100644 --- a/src/main/kotlin/plus/maa/backend/config/security/AuthenticationHelper.kt +++ b/src/main/kotlin/plus/maa/backend/config/security/AuthenticationHelper.kt @@ -8,7 +8,7 @@ import org.springframework.stereotype.Component import org.springframework.web.context.request.RequestContextHolder import org.springframework.web.context.request.ServletRequestAttributes import org.springframework.web.server.ResponseStatusException -import plus.maa.backend.common.utils.IpUtil.getIpAddr +import plus.maa.backend.common.utils.IpUtil import plus.maa.backend.service.jwt.JwtAuthToken import plus.maa.backend.service.model.LoginUser import java.util.* @@ -35,39 +35,35 @@ class AuthenticationHelper { */ @Throws(ResponseStatusException::class) fun requireUserId(): String { - val id = userId ?: throw ResponseStatusException(HttpStatus.UNAUTHORIZED) - return id + return obtainUserId() ?: throw ResponseStatusException(HttpStatus.UNAUTHORIZED) } - val userId: String? - /** - * 获取用户 id - * - * @return 用户 id,如未验证则返回 null - */ - get() { - val auth = SecurityContextHolder.getContext().authentication ?: return null - if (auth is UsernamePasswordAuthenticationToken) { - val principal = auth.getPrincipal() - if (principal is LoginUser) return principal.userId - } else if (auth is JwtAuthToken) { - return auth.subject - } - return null + /** + * 获取用户 id + * + * @return 用户 id,如未验证则返回 null + */ + fun obtainUserId(): String? { + val auth = SecurityContextHolder.getContext().authentication ?: return null + if (auth is UsernamePasswordAuthenticationToken) { + val user = auth.getPrincipal() as? LoginUser + return user?.userId + } else if (auth is JwtAuthToken) { + return auth.subject } + return null + } - val userIdOrIpAddress: String - /** - * 获取已验证用户 id 或者未验证用户 ip 地址。在 HTTP request 之外调用该方法获取 ip 会抛出 NPE - * - * @return 用户 id 或者 ip 地址 - */ - get() { - val id = userId - if (id != null) return id + /** + * 获取已验证用户 id 或者未验证用户 ip 地址。在 HTTP request 之外调用该方法获取 ip 会抛出 [IllegalStateException] + * + * @return 用户 id 或者 ip 地址 + */ + fun obtainUserIdOrIpAddress(): String { + val id = obtainUserId() + if (id != null) return id - val attributes = Objects.requireNonNull(RequestContextHolder.getRequestAttributes()) - val request = (attributes as ServletRequestAttributes).request - return getIpAddr(request) - } + val request = (RequestContextHolder.getRequestAttributes() as? ServletRequestAttributes)?.request + return checkNotNull(request).run(IpUtil::getIpAddr) + } } diff --git a/src/main/kotlin/plus/maa/backend/controller/CopilotController.kt b/src/main/kotlin/plus/maa/backend/controller/CopilotController.kt index b56c5322..2bc3d86b 100644 --- a/src/main/kotlin/plus/maa/backend/controller/CopilotController.kt +++ b/src/main/kotlin/plus/maa/backend/controller/CopilotController.kt @@ -62,7 +62,7 @@ class CopilotController( fun getCopilotById( @Parameter(description = "作业id") @PathVariable("id") id: Long ): MaaResult { - val userIdOrIpAddress = helper.userIdOrIpAddress + val userIdOrIpAddress = helper.obtainUserIdOrIpAddress() return copilotService.getCopilotById(userIdOrIpAddress, id)?.let { success(it) } ?: fail(404, "作业不存在") } @@ -76,7 +76,7 @@ class CopilotController( ): MaaResult { // 三秒防抖,缓解前端重复请求问题 response.setHeader(HttpHeaders.CACHE_CONTROL, "private, max-age=3, must-revalidate") - return success(copilotService.queriesCopilot(helper.userId, parsed)) + return success(copilotService.queriesCopilot(helper.obtainUserId(), parsed)) } @Operation(summary = "更新作业") @@ -95,7 +95,7 @@ class CopilotController( @JsonSchema @PostMapping("/rating") fun ratesCopilotOperation(@RequestBody copilotRatingReq: CopilotRatingReq): MaaResult { - copilotService.rates(helper.userIdOrIpAddress, copilotRatingReq) + copilotService.rates(helper.obtainUserIdOrIpAddress(), copilotRatingReq) return success("评分成功") } @@ -104,7 +104,7 @@ class CopilotController( @ApiResponse(description = "success") @GetMapping("/status") fun modifyStatus(@RequestParam id: @NotBlank Long, @RequestParam status: Boolean): MaaResult { - copilotService.notificationStatus(helper.userId, id, status) + copilotService.notificationStatus(helper.obtainUserId(), id, status) return success("success") } } diff --git a/src/main/kotlin/plus/maa/backend/controller/CopilotSetController.kt b/src/main/kotlin/plus/maa/backend/controller/CopilotSetController.kt index 30b4bfc2..5b5b7908 100644 --- a/src/main/kotlin/plus/maa/backend/controller/CopilotSetController.kt +++ b/src/main/kotlin/plus/maa/backend/controller/CopilotSetController.kt @@ -50,7 +50,7 @@ class CopilotSetController( @RequireJwt @PostMapping("/create") fun createSet(@RequestBody req: @Valid CopilotSetCreateReq): MaaResult { - return success(service.create(req, helper.userId)) + return success(service.create(req, helper.obtainUserId())) } @Operation(summary = "添加作业集作业列表") diff --git a/src/main/kotlin/plus/maa/backend/controller/file/FileController.kt b/src/main/kotlin/plus/maa/backend/controller/file/FileController.kt index 85226975..801b218c 100644 --- a/src/main/kotlin/plus/maa/backend/controller/file/FileController.kt +++ b/src/main/kotlin/plus/maa/backend/controller/file/FileController.kt @@ -44,7 +44,7 @@ class FileController( @RequestPart(required = false) classification: String?, @RequestPart(required = false) label: String ): MaaResult { - fileService.uploadFile(file, type, version, classification, label, helper.userIdOrIpAddress) + fileService.uploadFile(file, type, version, classification, label, helper.obtainUserIdOrIpAddress()) return success("上传成功,数据已被接收") } diff --git a/src/main/kotlin/plus/maa/backend/repository/entity/MaaUser.kt b/src/main/kotlin/plus/maa/backend/repository/entity/MaaUser.kt index 0cfb032b..cda7b5f2 100644 --- a/src/main/kotlin/plus/maa/backend/repository/entity/MaaUser.kt +++ b/src/main/kotlin/plus/maa/backend/repository/entity/MaaUser.kt @@ -5,7 +5,6 @@ import org.springframework.data.annotation.Id import org.springframework.data.annotation.Transient import org.springframework.data.mongodb.core.index.Indexed import org.springframework.data.mongodb.core.mapping.Document -import plus.maa.backend.controller.request.user.UserInfoUpdateDTO import java.io.Serializable /** @@ -25,13 +24,6 @@ data class MaaUser( var refreshJwtIds: MutableList = ArrayList() ) : Serializable { - fun updateAttribute(updateDTO: UserInfoUpdateDTO) { - val userName = updateDTO.userName - if (userName.isNotBlank()) { - this.userName = userName - } - } - companion object { @Transient val UNKNOWN: MaaUser = MaaUser( diff --git a/src/main/kotlin/plus/maa/backend/service/UserService.kt b/src/main/kotlin/plus/maa/backend/service/UserService.kt index a8fd5cff..afb9f294 100644 --- a/src/main/kotlin/plus/maa/backend/service/UserService.kt +++ b/src/main/kotlin/plus/maa/backend/service/UserService.kt @@ -1,7 +1,7 @@ package plus.maa.backend.service -import org.springframework.beans.BeanUtils import org.springframework.dao.DuplicateKeyException +import org.springframework.data.repository.findByIdOrNull import org.springframework.security.crypto.password.PasswordEncoder import org.springframework.stereotype.Service import plus.maa.backend.common.MaaStatusCode @@ -78,9 +78,7 @@ class UserService( originPassword: String? = null, verifyOriginPassword: Boolean = true ) { - val userResult = userRepository.findById(userId) - if (userResult.isEmpty) return - val maaUser = userResult.get() + val maaUser = userRepository.findByIdOrNull(userId) ?: return if (verifyOriginPassword) { check(!originPassword.isNullOrEmpty()) { "请输入原密码" @@ -102,27 +100,22 @@ class UserService( * @return 返回注册成功的用户摘要(脱敏) */ fun register(registerDTO: RegisterDTO): MaaUserInfo { - val encode = passwordEncoder.encode(registerDTO.password) - // 校验验证码 emailService.verifyVCode(registerDTO.email, registerDTO.registrationToken) + val encoded = passwordEncoder.encode(registerDTO.password) + val user = MaaUser( userName = registerDTO.userName, email = registerDTO.email, - password = registerDTO.password + password = encoded, + status = 1, ) - BeanUtils.copyProperties(registerDTO, user) - user.password = encode - user.status = 1 - val userInfo: MaaUserInfo - try { - val save = userRepository.save(user) - userInfo = MaaUserInfo(save) + return try { + userRepository.save(user).run(::MaaUserInfo) } catch (e: DuplicateKeyException) { throw MaaResultException(MaaStatusCode.MAA_USER_EXISTS) } - return userInfo } /** @@ -132,10 +125,9 @@ class UserService( * @param updateDTO 更新参数 */ fun updateUserInfo(userId: String, updateDTO: UserInfoUpdateDTO) { - userRepository.findById(userId).ifPresent { maaUser: MaaUser -> - maaUser.updateAttribute(updateDTO) - userRepository.save(maaUser) - } + val maaUser = userRepository.findByIdOrNull(userId) ?: return + maaUser.userName = updateDTO.userName + userRepository.save(maaUser) } /**