Skip to content

Commit

Permalink
refactor: AuthenticationHelper
Browse files Browse the repository at this point in the history
  • Loading branch information
Handiwork committed Mar 15, 2024
1 parent 1bef800 commit 092890e
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import org.springframework.stereotype.Component
import org.springframework.web.context.request.RequestContextHolder
import org.springframework.web.context.request.ServletRequestAttributes
import org.springframework.web.server.ResponseStatusException
import plus.maa.backend.common.utils.IpUtil.getIpAddr
import plus.maa.backend.common.utils.IpUtil
import plus.maa.backend.service.jwt.JwtAuthToken
import plus.maa.backend.service.model.LoginUser
import java.util.*
Expand All @@ -35,39 +35,35 @@ class AuthenticationHelper {
*/
@Throws(ResponseStatusException::class)
fun requireUserId(): String {
val id = userId ?: throw ResponseStatusException(HttpStatus.UNAUTHORIZED)
return id
return obtainUserId() ?: throw ResponseStatusException(HttpStatus.UNAUTHORIZED)
}

val userId: String?
/**
* 获取用户 id
*
* @return 用户 id,如未验证则返回 null
*/
get() {
val auth = SecurityContextHolder.getContext().authentication ?: return null
if (auth is UsernamePasswordAuthenticationToken) {
val principal = auth.getPrincipal()
if (principal is LoginUser) return principal.userId
} else if (auth is JwtAuthToken) {
return auth.subject
}
return null
/**
* 获取用户 id
*
* @return 用户 id,如未验证则返回 null
*/
fun obtainUserId(): String? {
val auth = SecurityContextHolder.getContext().authentication ?: return null
if (auth is UsernamePasswordAuthenticationToken) {
val user = auth.getPrincipal() as? LoginUser
return user?.userId
} else if (auth is JwtAuthToken) {
return auth.subject
}
return null
}

val userIdOrIpAddress: String
/**
* 获取已验证用户 id 或者未验证用户 ip 地址。在 HTTP request 之外调用该方法获取 ip 会抛出 NPE
*
* @return 用户 id 或者 ip 地址
*/
get() {
val id = userId
if (id != null) return id
/**
* 获取已验证用户 id 或者未验证用户 ip 地址。在 HTTP request 之外调用该方法获取 ip 会抛出 [IllegalStateException]
*
* @return 用户 id 或者 ip 地址
*/
fun obtainUserIdOrIpAddress(): String {
val id = obtainUserId()
if (id != null) return id

val attributes = Objects.requireNonNull(RequestContextHolder.getRequestAttributes())
val request = (attributes as ServletRequestAttributes).request
return getIpAddr(request)
}
val request = (RequestContextHolder.getRequestAttributes() as? ServletRequestAttributes)?.request
return checkNotNull(request).run(IpUtil::getIpAddr)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class CopilotController(
fun getCopilotById(
@Parameter(description = "作业id") @PathVariable("id") id: Long
): MaaResult<CopilotInfo?> {
val userIdOrIpAddress = helper.userIdOrIpAddress
val userIdOrIpAddress = helper.obtainUserIdOrIpAddress()
return copilotService.getCopilotById(userIdOrIpAddress, id)?.let { success(it) }
?: fail(404, "作业不存在")
}
Expand All @@ -76,7 +76,7 @@ class CopilotController(
): MaaResult<CopilotPageInfo> {
// 三秒防抖,缓解前端重复请求问题
response.setHeader(HttpHeaders.CACHE_CONTROL, "private, max-age=3, must-revalidate")
return success(copilotService.queriesCopilot(helper.userId, parsed))
return success(copilotService.queriesCopilot(helper.obtainUserId(), parsed))
}

@Operation(summary = "更新作业")
Expand All @@ -95,7 +95,7 @@ class CopilotController(
@JsonSchema
@PostMapping("/rating")
fun ratesCopilotOperation(@RequestBody copilotRatingReq: CopilotRatingReq): MaaResult<String> {
copilotService.rates(helper.userIdOrIpAddress, copilotRatingReq)
copilotService.rates(helper.obtainUserIdOrIpAddress(), copilotRatingReq)
return success("评分成功")
}

Expand All @@ -104,7 +104,7 @@ class CopilotController(
@ApiResponse(description = "success")
@GetMapping("/status")
fun modifyStatus(@RequestParam id: @NotBlank Long, @RequestParam status: Boolean): MaaResult<String> {
copilotService.notificationStatus(helper.userId, id, status)
copilotService.notificationStatus(helper.obtainUserId(), id, status)
return success("success")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CopilotSetController(
@RequireJwt
@PostMapping("/create")
fun createSet(@RequestBody req: @Valid CopilotSetCreateReq): MaaResult<Long> {
return success(service.create(req, helper.userId))
return success(service.create(req, helper.obtainUserId()))
}

@Operation(summary = "添加作业集作业列表")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class FileController(
@RequestPart(required = false) classification: String?,
@RequestPart(required = false) label: String
): MaaResult<String> {
fileService.uploadFile(file, type, version, classification, label, helper.userIdOrIpAddress)
fileService.uploadFile(file, type, version, classification, label, helper.obtainUserIdOrIpAddress())
return success("上传成功,数据已被接收")
}

Expand Down
8 changes: 0 additions & 8 deletions src/main/kotlin/plus/maa/backend/repository/entity/MaaUser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import org.springframework.data.annotation.Id
import org.springframework.data.annotation.Transient
import org.springframework.data.mongodb.core.index.Indexed
import org.springframework.data.mongodb.core.mapping.Document
import plus.maa.backend.controller.request.user.UserInfoUpdateDTO
import java.io.Serializable

/**
Expand All @@ -25,13 +24,6 @@ data class MaaUser(
var refreshJwtIds: MutableList<String> = ArrayList()
) : Serializable {

fun updateAttribute(updateDTO: UserInfoUpdateDTO) {
val userName = updateDTO.userName
if (userName.isNotBlank()) {
this.userName = userName
}
}

companion object {
@Transient
val UNKNOWN: MaaUser = MaaUser(
Expand Down
30 changes: 11 additions & 19 deletions src/main/kotlin/plus/maa/backend/service/UserService.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package plus.maa.backend.service

import org.springframework.beans.BeanUtils
import org.springframework.dao.DuplicateKeyException
import org.springframework.data.repository.findByIdOrNull
import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.stereotype.Service
import plus.maa.backend.common.MaaStatusCode
Expand Down Expand Up @@ -78,9 +78,7 @@ class UserService(
originPassword: String? = null,
verifyOriginPassword: Boolean = true
) {
val userResult = userRepository.findById(userId)
if (userResult.isEmpty) return
val maaUser = userResult.get()
val maaUser = userRepository.findByIdOrNull(userId) ?: return
if (verifyOriginPassword) {
check(!originPassword.isNullOrEmpty()) {
"请输入原密码"
Expand All @@ -102,27 +100,22 @@ class UserService(
* @return 返回注册成功的用户摘要(脱敏)
*/
fun register(registerDTO: RegisterDTO): MaaUserInfo {
val encode = passwordEncoder.encode(registerDTO.password)

// 校验验证码
emailService.verifyVCode(registerDTO.email, registerDTO.registrationToken)

val encoded = passwordEncoder.encode(registerDTO.password)

val user = MaaUser(
userName = registerDTO.userName,
email = registerDTO.email,
password = registerDTO.password
password = encoded,
status = 1,
)
BeanUtils.copyProperties(registerDTO, user)
user.password = encode
user.status = 1
val userInfo: MaaUserInfo
try {
val save = userRepository.save(user)
userInfo = MaaUserInfo(save)
return try {
userRepository.save(user).run(::MaaUserInfo)
} catch (e: DuplicateKeyException) {
throw MaaResultException(MaaStatusCode.MAA_USER_EXISTS)
}
return userInfo
}

/**
Expand All @@ -132,10 +125,9 @@ class UserService(
* @param updateDTO 更新参数
*/
fun updateUserInfo(userId: String, updateDTO: UserInfoUpdateDTO) {
userRepository.findById(userId).ifPresent { maaUser: MaaUser ->
maaUser.updateAttribute(updateDTO)
userRepository.save(maaUser)
}
val maaUser = userRepository.findByIdOrNull(userId) ?: return
maaUser.userName = updateDTO.userName
userRepository.save(maaUser)
}

/**
Expand Down

0 comments on commit 092890e

Please sign in to comment.