boxes-api/auth.go

182 lines
4.9 KiB
Go
Raw Normal View History

package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"golang.org/x/crypto/bcrypt"
)
// contextKey is an unexported string type used for request-context keys.
// Because the type is private to this package, keys of this type cannot
// collide with context keys defined by other packages.
type contextKey string

// userKey is the context key under which AuthMiddleware stores the
// authenticated username.
const userKey = contextKey("user")
type (
	// LoginRequest is the JSON body accepted by the /login endpoint:
	// {"username": "...", "password": "..."}.
	LoginRequest struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}

	// LoginResponse is the JSON body returned by a successful /login call:
	// {"token": "..."}.
	LoginResponse struct {
		Token string `json:"token"`
	}
)
// init announces the authentication module through the package logger, when
// GetLogger returns a non-nil logger at package-initialisation time.
//
// NOTE(review): the debug line dumps the entire config value with %+v. If
// config carries credentials (database password, JWT secret, ...), they will
// be written to the logs. Confirm what config contains.
func init() {
	logger := GetLogger()
	if logger == nil {
		return
	}
	logger.Info("Initializing authentication module")
	logger.Debug("Current config: %+v", config)
}
// LoginHandler handles the /login endpoint.
//
// It decodes a JSON LoginRequest from the request body and loads the User
// whose username matches via db. It then checks the supplied password against
// the bcrypt hash stored in user.Password. On success it replies with a JSON
// LoginResponse whose token is an HS256 JWT signed with *JWTSecret. The token
// carries a "username" claim and an "exp" claim 24 hours from now.
//
// Error responses:
//   - 500 "Database not initialized" when db is nil;
//   - 400 "Invalid request body" when the body is not valid JSON or is larger
//     than maxLoginBodyBytes;
//   - 401 "Invalid username or password" for both an unknown user and a wrong
//     password, so the response text does not reveal which usernames exist;
//   - 500 "Failed to generate token" when signing fails.
//
// NOTE(review): the unknown-user path returns without running bcrypt, so
// response timing can still hint at whether a username exists.
func LoginHandler(w http.ResponseWriter, r *http.Request) {
	// A login payload is two short strings; cap the body so a client cannot
	// make the JSON decoder buffer an arbitrarily large request.
	const maxLoginBodyBytes = 1 << 20

	log := GetLogger()
	if db == nil {
		if log != nil {
			log.Error("Database connection not initialized in LoginHandler")
		}
		http.Error(w, "Database not initialized", http.StatusInternalServerError)
		return
	}
	if log != nil {
		log.Info("Processing login request")
	}

	var req LoginRequest
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		if log != nil {
			log.Error("Failed to decode login request: %v", err)
		}
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Look the user up. Any lookup error is treated as a failed login, but
	// the error itself is logged so a database failure can be told apart
	// from a bad username.
	// NOTE(review): a database outage still answers 401. Returning 500
	// instead needs the ORM's "record not found" sentinel, which this file
	// does not import.
	var user User
	result := db.Where("username = ?", req.Username).First(&user)
	if result.Error != nil || user.ID == 0 {
		if log != nil {
			// %q: the username is raw client input; quoting it keeps embedded
			// newlines/control characters from forging extra log lines.
			if result.Error != nil {
				log.Warn("Login attempt failed for username: %q - user lookup error: %v", req.Username, result.Error)
			} else {
				log.Warn("Login attempt failed for username: %q - user not found", req.Username)
			}
		}
		http.Error(w, "Invalid username or password", http.StatusUnauthorized)
		return
	}

	// Compare the provided password with the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
		if log != nil {
			log.Warn("Login attempt failed for username: %q - invalid password", req.Username)
		}
		http.Error(w, "Invalid username or password", http.StatusUnauthorized)
		return
	}

	// Issue a JWT that expires in 24 hours.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"username": user.Username,
		"exp":      time.Now().Add(24 * time.Hour).Unix(),
	})
	tokenString, err := token.SignedString(*JWTSecret)
	if err != nil {
		if log != nil {
			log.Error("Failed to generate JWT token for user %s: %v", user.Username, err)
		}
		http.Error(w, "Failed to generate token", http.StatusInternalServerError)
		return
	}

	if log != nil {
		log.Info("Successful login for user: %s", user.Username)
		log.UserAction(user.Username, "login")
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(LoginResponse{Token: tokenString}); err != nil && log != nil {
		// The response may already be partially written, so the failure can
		// only be logged, not turned into an error status.
		log.Error("Failed to write login response for user %s: %v", user.Username, err)
	}
}
// AuthMiddleware wraps next with CORS headers and JWT bearer-token
// authentication.
//
// Every response gets CORS headers allowing any origin, the methods GET, POST,
// PUT, DELETE and OPTIONS, and the request headers Content-Type and
// Authorization. OPTIONS (preflight) requests are answered with 200 without
// authentication.
//
// Every other request must send an Authorization header of the form
// "Bearer <token>". The scheme is matched case-insensitively, and a bare
// token without the scheme is still accepted, as before. The token must:
//   - be HMAC-signed with *JWTSecret;
//   - carry an "exp" claim that has not passed;
//   - carry a non-empty string "username" claim.
//
// On success the username is stored in the request context under userKey
// and next is called. Otherwise the request is rejected with 401.
func AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log := GetLogger()

		// Set CORS headers.
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		// Preflight requests carry no credentials; answer them directly.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			if log != nil {
				log.Warn("Request rejected: missing Authorization header")
			}
			http.Error(w, "Authorization header missing", http.StatusUnauthorized)
			return
		}

		// Strip the scheme only when it is an actual prefix; the previous
		// strings.Replace also removed a "Bearer " found anywhere in the
		// value. The auth scheme is case-insensitive (RFC 7235).
		const bearerPrefix = "Bearer "
		tokenString := authHeader
		if len(tokenString) >= len(bearerPrefix) && strings.EqualFold(tokenString[:len(bearerPrefix)], bearerPrefix) {
			tokenString = strings.TrimSpace(tokenString[len(bearerPrefix):])
		}

		// Parse and validate the JWT, accepting only HMAC signing methods so
		// a token cannot pick a different algorithm family.
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				if log != nil {
					log.Warn("Invalid signing method in token: %v", token.Header["alg"])
				}
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return *JWTSecret, nil
		})
		if err != nil || !token.Valid {
			if log != nil {
				log.Warn("Invalid token: %v", err)
			}
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			if log != nil {
				log.Warn("Invalid token claims structure")
			}
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		// jwt.Parse only enforces "exp" when the claim is present; require it
		// so a signed token without an expiry cannot be used forever.
		if !claims.VerifyExpiresAt(time.Now().Unix(), true) {
			if log != nil {
				log.Warn("Token rejected: missing or expired exp claim")
			}
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		// Checked assertion: a validly signed token without a string
		// "username" claim previously caused a panic here.
		username, ok := claims["username"].(string)
		if !ok || username == "" {
			if log != nil {
				log.Warn("Token rejected: missing or non-string username claim")
			}
			http.Error(w, "Invalid token claims", http.StatusUnauthorized)
			return
		}

		// Expose the authenticated username to downstream handlers.
		r = r.WithContext(context.WithValue(r.Context(), userKey, username))
		if log != nil {
			log.Debug("Authenticated request for user: %s", username)
		}

		next.ServeHTTP(w, r)
	})
}