DeployHelper/internal/middleware/cors.go

74 lines
2.1 KiB
Go
Raw Normal View History

2025-08-01 16:38:08 +08:00
package middleware
import (
_ "ego/docs"
"ego/pkg/logger"
"os"
"regexp"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Cors 跨域配置(支持环境变量动态配置)
func Cors() gin.HandlerFunc {
config := cors.DefaultConfig()
// 设置基础配置
config.AllowMethods = []string{
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
}
config.AllowHeaders = []string{
"Origin", "Content-Length", "Content-Type", "Cookie", "Authorization",
"X-Requested-With", "X-CSRF-Token", // 添加常用自定义头
}
config.AllowCredentials = true
config.MaxAge = 12 * time.Hour // 预检请求缓存12小时
// 根据环境配置允许的来源
if gin.Mode() == gin.ReleaseMode {
// 生产环境:从环境变量读取允许的域名(逗号分隔)
origins := os.Getenv("CORS_ALLOWED_ORIGINS")
if origins == "" {
// 如果环境变量未设置,使用默认的安全配置
logger.Warn(nil, "CORS_ALLOWED_ORIGINS 环境变量未设置,使用默认安全配置")
config.AllowOrigins = []string{
"https://yourdomain.com", // 请替换为实际域名
}
} else {
// 清理空白字符并分割
originList := make([]string, 0)
for _, origin := range strings.Split(origins, ",") {
trimmed := strings.TrimSpace(origin)
if trimmed != "" {
originList = append(originList, trimmed)
}
}
config.AllowOrigins = originList
}
} else {
// 开发环境:匹配本地开发域名
config.AllowOriginFunc = func(origin string) bool {
// 匹配 http://localhost:端口 或 http://127.0.0.1:端口
re := regexp.MustCompile(`^http://(localhost|127\.0\.0\.1):\d+$`)
return re.MatchString(origin)
}
}
// 输出当前生效的配置(方便调试)
if logger.Logger != nil {
logger.Info(nil, "跨域配置初始化完成",
zap.Any("允许方法", config.AllowMethods),
zap.Any("允许头", config.AllowHeaders),
zap.Any("允许来源", config.AllowOrigins),
zap.Duration("MaxAge", config.MaxAge),
zap.String("模式", gin.Mode()),
)
}
return cors.New(config)
}