Feat(ktm-booking): Initial commit
Some checks failed
ktm-booking-bot/ktm-booking-bot/pipeline/head Something is wrong with the build of this commit
Some checks failed
ktm-booking-bot/ktm-booking-bot/pipeline/head Something is wrong with the build of this commit
This commit is contained in:
50
backend/internal/user/main.go
Normal file
50
backend/internal/user/main.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Env struct {
|
||||
DB *gorm.DB
|
||||
CookieString string
|
||||
}
|
||||
|
||||
func UserRoutes(db *gorm.DB) chi.Router {
|
||||
var env Env
|
||||
env.DB = db
|
||||
env.CookieString = os.Getenv("COOKIE_STRING")
|
||||
if env.CookieString == "" {
|
||||
env.CookieString = "cookie_string"
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
allowRegistration := os.Getenv("ALLOW_REGISTRATION")
|
||||
if strings.ToUpper(allowRegistration) == "TRUE" || allowRegistration == "1" {
|
||||
log.Println("Registration enabled.")
|
||||
r.Post("/register", env.registerRouteHandler)
|
||||
}
|
||||
r.Post("/login", env.loginRouteHandler)
|
||||
r.Post("/logout", env.logoutRouteHandler)
|
||||
|
||||
checkLoggedInUserGroup := r.Group(nil)
|
||||
checkLoggedInUserGroup.Use(env.CheckUserMiddleware)
|
||||
checkLoggedInUserGroup.Get("/me", env.meRouteHandler)
|
||||
checkLoggedInUserGroup.Put("/profile", env.setProfileRouteHandler)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func NewUserEnv(db *gorm.DB) *Env {
|
||||
var env Env
|
||||
env.DB = db
|
||||
env.CookieString = os.Getenv("COOKIE_STRING")
|
||||
if env.CookieString == "" {
|
||||
env.CookieString = "cooking_string"
|
||||
}
|
||||
return &env
|
||||
}
|
||||
146
backend/internal/user/main_test.go
Normal file
146
backend/internal/user/main_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
)
|
||||
|
||||
func TestUserRoutesRegistration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
allowRegistrationFlag bool
|
||||
allowRegistrationFlagSet bool
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
allowRegistrationFlag: true,
|
||||
allowRegistrationFlagSet: true,
|
||||
statusCode: 201,
|
||||
},
|
||||
{
|
||||
allowRegistrationFlag: false,
|
||||
allowRegistrationFlagSet: true,
|
||||
statusCode: 404,
|
||||
},
|
||||
{
|
||||
allowRegistrationFlag: false,
|
||||
allowRegistrationFlagSet: false,
|
||||
statusCode: 404,
|
||||
},
|
||||
}
|
||||
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
|
||||
for _, currTestCase := range testCases {
|
||||
if currTestCase.allowRegistrationFlagSet {
|
||||
if currTestCase.allowRegistrationFlag {
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
} else {
|
||||
t.Setenv("ALLOW_REGISTRATION", "false")
|
||||
}
|
||||
} else {
|
||||
t.Setenv("ALLOW_REGISTRATION", "")
|
||||
}
|
||||
router := UserRoutes(db)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
reqBody := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
|
||||
reqBodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader := bytes.NewReader(reqBodyBytes)
|
||||
|
||||
req, err := http.NewRequest("POST", "/register", reqBodyReader)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request: %v", err)
|
||||
}
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
currStatusCode := rr.Result().StatusCode
|
||||
if currStatusCode != currTestCase.statusCode {
|
||||
t.Errorf("Wrong status code: Expected - %d. Got - %d", currTestCase.statusCode, rr.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRoutesLoggedIn(t *testing.T) {
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
router := UserRoutes(db)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Testing /me
|
||||
req, err := http.NewRequest("GET", "/me", nil)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request: %v", err)
|
||||
}
|
||||
router.ServeHTTP(rr, req)
|
||||
if rr.Result().StatusCode != 500 {
|
||||
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
|
||||
}
|
||||
|
||||
responseBytes, err := ioutil.ReadAll(rr.Result().Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error reading response body: %v", err)
|
||||
}
|
||||
|
||||
var results map[string]any
|
||||
err = json.Unmarshal(responseBytes, &results)
|
||||
if err != nil {
|
||||
t.Errorf("Error decoding response body: %v", err)
|
||||
}
|
||||
|
||||
if results["error"] != "user not logged in" {
|
||||
t.Errorf("Error for %s not correct", "/me")
|
||||
}
|
||||
|
||||
// Testing /profile
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
emptyJSON := bytes.NewReader([]byte("{}"))
|
||||
req, err = http.NewRequest("PUT", "/profile", emptyJSON)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request: %v", err)
|
||||
}
|
||||
router.ServeHTTP(rr, req)
|
||||
if rr.Result().StatusCode != 500 {
|
||||
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
|
||||
}
|
||||
|
||||
responseBytes, err = ioutil.ReadAll(rr.Result().Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error reading response body: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBytes, &results)
|
||||
if err != nil {
|
||||
t.Errorf("Error decoding response body: %v", err)
|
||||
}
|
||||
|
||||
if results["error"] != "user not logged in" {
|
||||
t.Errorf("Error for %s not correct", "/me")
|
||||
}
|
||||
}
|
||||
29
backend/internal/user/profilecontroller.go
Normal file
29
backend/internal/user/profilecontroller.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (env *Env) setProfile(currUser *User, ktmTrainUsername string, ktmTrainPassword string, ktmTrainCreditCardType string, ktmTrainCreditCard string, ktmTrainCreditCardExpiry string, ktmTrainCreditCardCVV string) (*User, error) {
|
||||
profile := &Profile{
|
||||
UserID: currUser.ID,
|
||||
KtmTrainUsername: ktmTrainUsername,
|
||||
KtmTrainPassword: ktmTrainPassword,
|
||||
KtmTrainCreditCardType: ktmTrainCreditCardType,
|
||||
KtmTrainCreditCard: ktmTrainCreditCard,
|
||||
KtmTrainCreditCardExpiry: ktmTrainCreditCardExpiry,
|
||||
KtmTrainCreditCardCVV: ktmTrainCreditCardCVV,
|
||||
}
|
||||
|
||||
if err := env.DB.Clauses(clause.OnConflict{
|
||||
UpdateAll: true,
|
||||
}).Create(profile).Error; err != nil {
|
||||
log.Println("Error creating profile", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currUser.Profile = *profile
|
||||
return currUser, nil
|
||||
}
|
||||
52
backend/internal/user/profilemodel.go
Normal file
52
backend/internal/user/profilemodel.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Profile struct {
|
||||
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
UserID uuid.UUID `gorm:"index"`
|
||||
KtmTrainUsername string
|
||||
KtmTrainPassword string
|
||||
KtmTrainCreditCardType string // Visa/Mastercard
|
||||
KtmTrainCreditCard string
|
||||
KtmTrainCreditCardExpiry string
|
||||
KtmTrainCreditCardCVV string
|
||||
}
|
||||
|
||||
type ProfileRequest struct {
|
||||
KtmTrainUsername string `json:"ktmTrainUsername"`
|
||||
KtmTrainPassword string `json:"ktmTrainPassword"`
|
||||
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
|
||||
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
|
||||
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
|
||||
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
|
||||
}
|
||||
|
||||
type ProfileResponse struct {
|
||||
KtmTrainUsername string `json:"ktmTrainUsername"`
|
||||
KtmTrainPassword string `json:"ktmTrainPassword"`
|
||||
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
|
||||
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
|
||||
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
|
||||
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
|
||||
}
|
||||
|
||||
func (env *Env) NewProfileResponse(model *Profile) *ProfileResponse {
|
||||
res := &ProfileResponse{
|
||||
KtmTrainUsername: model.KtmTrainUsername,
|
||||
KtmTrainPassword: model.KtmTrainPassword,
|
||||
KtmTrainCreditCardType: model.KtmTrainCreditCardType,
|
||||
KtmTrainCreditCard: model.KtmTrainCreditCard,
|
||||
KtmTrainCreditCardExpiry: model.KtmTrainCreditCardExpiry,
|
||||
KtmTrainCreditCardCVV: model.KtmTrainCreditCardCVV,
|
||||
}
|
||||
return res
|
||||
}
|
||||
45
backend/internal/user/profileroute.go
Normal file
45
backend/internal/user/profileroute.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
// Set current user profile
|
||||
// @Summary For setting current user profile
|
||||
// @Description Description
|
||||
// @Tags User
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param user body ProfileRequest true "User registration info"
|
||||
// @Success 200 {object} UserResponse
|
||||
// @Failure 400 {object} common.ErrResponse
|
||||
// @Router /api/v1/user/profile [put]
|
||||
func (env *Env) setProfileRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
data := &ProfileRequest{}
|
||||
err := render.DecodeJSON(r.Body, data)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
currUser, ok := ctx.Value(UserContextKey).(*User)
|
||||
if !ok {
|
||||
err := errors.New("user not logged in")
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
currUser, err = env.setProfile(currUser, data.KtmTrainUsername, data.KtmTrainPassword, data.KtmTrainCreditCardType, data.KtmTrainCreditCard, data.KtmTrainCreditCardExpiry, data.KtmTrainCreditCardCVV)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.Render(w, r, env.NewUserResponse(currUser))
|
||||
}
|
||||
174
backend/internal/user/profileroute_test.go
Normal file
174
backend/internal/user/profileroute_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
)
|
||||
|
||||
func TestSetProfile(t *testing.T) {
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
t.Setenv("COOKIE_STRING", "supercustomcookie")
|
||||
router := UserRoutes(db)
|
||||
testCases := []struct {
|
||||
ktmTrainUsername string
|
||||
ktmTrainPassword string
|
||||
ktmTrainCreditCardType string
|
||||
ktmTrainCreditCard string
|
||||
ktmTrainCreditCardExpiry string
|
||||
ktmTrainCreditCardCVV string
|
||||
statusCode int
|
||||
cookieEnabled bool
|
||||
}{
|
||||
{
|
||||
ktmTrainUsername: "test",
|
||||
ktmTrainPassword: "test",
|
||||
ktmTrainCreditCardType: "Visa",
|
||||
ktmTrainCreditCard: "1234567890123456",
|
||||
ktmTrainCreditCardExpiry: "05/2025",
|
||||
ktmTrainCreditCardCVV: "123",
|
||||
statusCode: 200,
|
||||
cookieEnabled: true,
|
||||
},
|
||||
{
|
||||
ktmTrainUsername: "",
|
||||
ktmTrainPassword: "",
|
||||
ktmTrainCreditCardType: "",
|
||||
ktmTrainCreditCard: "",
|
||||
ktmTrainCreditCardExpiry: "",
|
||||
ktmTrainCreditCardCVV: "",
|
||||
statusCode: 200,
|
||||
cookieEnabled: true,
|
||||
},
|
||||
{
|
||||
ktmTrainUsername: "test",
|
||||
ktmTrainPassword: "test",
|
||||
ktmTrainCreditCardType: "Visa",
|
||||
ktmTrainCreditCard: "1234567890123456",
|
||||
ktmTrainCreditCardExpiry: "05/2025",
|
||||
ktmTrainCreditCardCVV: "123",
|
||||
statusCode: 500,
|
||||
cookieEnabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Register user
|
||||
rr := httptest.NewRecorder()
|
||||
currBody := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
reqBody, err := json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader := bytes.NewReader(reqBody)
|
||||
req := httptest.NewRequest("POST", "/register", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
// Check registration results
|
||||
if rr.Code != 201 {
|
||||
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
|
||||
}
|
||||
|
||||
// Login to get cookie
|
||||
rr = httptest.NewRecorder()
|
||||
currBodyLogin := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
reqBody, err = json.Marshal(currBodyLogin)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader = bytes.NewReader(reqBody)
|
||||
req = httptest.NewRequest("POST", "/login", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
for _, currentTestCase := range testCases {
|
||||
// Start checking Profile
|
||||
rrProfile := httptest.NewRecorder()
|
||||
currBody := struct {
|
||||
KtmTrainUsername string `json:"ktmTrainUsername"`
|
||||
KtmTrainPassword string `json:"ktmTrainPassword"`
|
||||
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
|
||||
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
|
||||
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
|
||||
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
|
||||
}{
|
||||
KtmTrainUsername: currentTestCase.ktmTrainUsername,
|
||||
KtmTrainPassword: currentTestCase.ktmTrainPassword,
|
||||
KtmTrainCreditCardType: currentTestCase.ktmTrainCreditCardType,
|
||||
KtmTrainCreditCard: currentTestCase.ktmTrainCreditCard,
|
||||
KtmTrainCreditCardExpiry: currentTestCase.ktmTrainCreditCardExpiry,
|
||||
KtmTrainCreditCardCVV: currentTestCase.ktmTrainCreditCardCVV,
|
||||
}
|
||||
reqBody, err = json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader = bytes.NewReader(reqBody)
|
||||
req = httptest.NewRequest("PUT", "/profile", reqBodyReader)
|
||||
if currentTestCase.cookieEnabled {
|
||||
req.AddCookie(rr.Result().Cookies()[0])
|
||||
}
|
||||
router.ServeHTTP(rrProfile, req)
|
||||
|
||||
// Check Profile results
|
||||
if rrProfile.Code != currentTestCase.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
|
||||
}
|
||||
|
||||
if currentTestCase.statusCode == 200 {
|
||||
var resultsObj map[string]any
|
||||
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error reading response body: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resBodyBytes, &resultsObj)
|
||||
if err != nil {
|
||||
t.Errorf("Error unmarshalling response body: %v", err)
|
||||
}
|
||||
|
||||
resultsProfile := resultsObj["profile"].(map[string]any)
|
||||
|
||||
if resultsObj["username"] != "testusername" {
|
||||
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
|
||||
}
|
||||
if resultsProfile["ktmTrainUsername"] != currentTestCase.ktmTrainUsername {
|
||||
t.Errorf("Expected ktmTrainUsername %s, got %s", currentTestCase.ktmTrainUsername, resultsProfile["ktmTrainUsername"])
|
||||
}
|
||||
if resultsProfile["ktmTrainPassword"] != currentTestCase.ktmTrainPassword {
|
||||
t.Errorf("Expected ktmTrainPassword %s, got %s", currentTestCase.ktmTrainPassword, resultsProfile["ktmTrainPassword"])
|
||||
}
|
||||
if resultsProfile["ktmTrainCreditCardType"] != currentTestCase.ktmTrainCreditCardType {
|
||||
t.Errorf("Expected ktmTrainCreditCardType %s, got %s", currentTestCase.ktmTrainCreditCardType, resultsProfile["ktmTrainCreditCardType"])
|
||||
}
|
||||
if resultsProfile["ktmTrainCreditCard"] != currentTestCase.ktmTrainCreditCard {
|
||||
t.Errorf("Expected ktmTrainCreditCard %s, got %s", currentTestCase.ktmTrainCreditCard, resultsProfile["ktmTrainCreditCard"])
|
||||
}
|
||||
if resultsProfile["ktmTrainCreditCardExpiry"] != currentTestCase.ktmTrainCreditCardExpiry {
|
||||
t.Errorf("Expected ktmTrainCreditCardExpiry %s, got %s", currentTestCase.ktmTrainCreditCardExpiry, resultsProfile["ktmTrainCreditCardExpiry"])
|
||||
}
|
||||
if resultsProfile["ktmTrainCreditCardCVV"] != currentTestCase.ktmTrainCreditCardCVV {
|
||||
t.Errorf("Expected ktmTrainCreditCardCVV %s, got %s", currentTestCase.ktmTrainCreditCardCVV, resultsProfile["ktmTrainCreditCardCVV"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
56
backend/internal/user/sessioncontroller.go
Normal file
56
backend/internal/user/sessioncontroller.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (env *Env) createSession(user *User) (string, error) {
|
||||
var newSession Session
|
||||
newSession.UserID = user.ID
|
||||
newSession.SessionToken = uuid.New().String()
|
||||
|
||||
err := env.DB.Create(&newSession).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return "", errors.New("failed write to database")
|
||||
}
|
||||
|
||||
return newSession.SessionToken, nil
|
||||
}
|
||||
|
||||
func (env *Env) getUserFromSessionToken(sessionToken string) (*User, error) {
|
||||
var currUser User
|
||||
err := env.DB.Table("sessions").Select("users.*").Joins("left join users on users.id = sessions.user_id").Where("sessions.session_token = ?", sessionToken).First(&currUser).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, errors.New("failed get user")
|
||||
}
|
||||
|
||||
err = env.DB.Preload("Profile").Where(&currUser).First(&currUser).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, errors.New("failed get user")
|
||||
}
|
||||
|
||||
return &currUser, nil
|
||||
}
|
||||
|
||||
func (env *Env) logout(sessionToken string) error {
|
||||
var currSession Session
|
||||
err := env.DB.Where(&Session{SessionToken: sessionToken}).First(&currSession).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return errors.New("failed get session")
|
||||
}
|
||||
|
||||
err = env.DB.Delete(&currSession).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return errors.New("failed to logout")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
29
backend/internal/user/sessionmiddleware.go
Normal file
29
backend/internal/user/sessionmiddleware.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
func (env *Env) CheckUserMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(env.CookieString)
|
||||
if err != nil {
|
||||
err = errors.New("user not logged in")
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
user, err := env.getUserFromSessionToken(cookie.Value)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserContextKey, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
7
backend/internal/user/sessionmiddlewarecontext.go
Normal file
7
backend/internal/user/sessionmiddlewarecontext.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package user
|
||||
|
||||
const (
|
||||
UserContextKey ContextKey = "user"
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
17
backend/internal/user/sessionmodel.go
Normal file
17
backend/internal/user/sessionmodel.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
SessionToken string
|
||||
UserID uuid.UUID
|
||||
}
|
||||
71
backend/internal/user/usercontroller.go
Normal file
71
backend/internal/user/usercontroller.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
BCRYPTCOST = 12
|
||||
)
|
||||
|
||||
func (env *Env) createUser(username string, password string) (*User, error) {
|
||||
var createdUser User
|
||||
|
||||
passwordByte := []byte(password)
|
||||
createdHashBytes, err := bcrypt.GenerateFromPassword(passwordByte, BCRYPTCOST)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to generate bcrypt")
|
||||
}
|
||||
|
||||
createdUser.Username = username
|
||||
createdUser.PasswordBcrypt = string(createdHashBytes)
|
||||
return &createdUser, nil
|
||||
}
|
||||
|
||||
func (env *Env) registerUser(username string, password string) (*User, error) {
|
||||
newUser, err := env.createUser(username, password)
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, errors.New("failed to register user")
|
||||
}
|
||||
|
||||
// Check existing username
|
||||
var checkUser User
|
||||
env.DB.Where(&User{Username: username}).First(&checkUser)
|
||||
if checkUser.ID != uuid.Nil {
|
||||
log.Println(err)
|
||||
return nil, errors.New("user already exists")
|
||||
}
|
||||
|
||||
err = env.DB.Create(newUser).Error
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, errors.New("failed write to database")
|
||||
}
|
||||
|
||||
return newUser, nil
|
||||
}
|
||||
|
||||
func (env *Env) checkLogin(username string, password string) (*User, error) {
|
||||
var currUser User
|
||||
env.DB.Preload("Profile").Where(&User{Username: username}).First(&currUser)
|
||||
|
||||
// Prevent username enum by parsing password
|
||||
if currUser.ID == uuid.Nil {
|
||||
bcrypt.GenerateFromPassword([]byte{}, BCRYPTCOST)
|
||||
return nil, errors.New("invalid username or password")
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(currUser.PasswordBcrypt), []byte(password))
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid username or password")
|
||||
} else {
|
||||
return &currUser, nil
|
||||
}
|
||||
}
|
||||
52
backend/internal/user/usermodel.go
Normal file
52
backend/internal/user/usermodel.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
Username string
|
||||
PasswordBcrypt string
|
||||
Profile Profile
|
||||
Sessions []Session
|
||||
}
|
||||
|
||||
type UserResponse struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Profile *ProfileResponse `json:"profile"`
|
||||
}
|
||||
|
||||
type UserRegisterRequest struct {
|
||||
Username string `json:"username" validate:"required,min=2,max=100"`
|
||||
Password string `json:"password" validate:"required,min=6,max=100"`
|
||||
}
|
||||
|
||||
type UserLoginRequest struct {
|
||||
Username string `json:"username" validate:"required,min=2,max=100"`
|
||||
Password string `json:"password" validate:"required"`
|
||||
}
|
||||
|
||||
func (userResponse *UserResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
// Pre-processing before a response is marshalled and sent across the wire
|
||||
return nil
|
||||
}
|
||||
|
||||
func (env *Env) NewUserResponse(user *User) *UserResponse {
|
||||
profileResponse := env.NewProfileResponse(&user.Profile)
|
||||
userResponse := &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Profile: profileResponse,
|
||||
}
|
||||
|
||||
return userResponse
|
||||
}
|
||||
137
backend/internal/user/userroute.go
Normal file
137
backend/internal/user/userroute.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// User Register
|
||||
// @Summary For user registration
|
||||
// @Description Description
|
||||
// @Tags User
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param user body UserRegisterRequest true "User registration info"
|
||||
// @Success 200 {object} UserResponse
|
||||
// @Failure 400 {object} common.ErrResponse
|
||||
// @Router /api/v1/user/register [post]
|
||||
func (env *Env) registerRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
data := &UserRegisterRequest{}
|
||||
err := render.DecodeJSON(r.Body, data)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = validator.New().Struct(data)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrValidationError(err))
|
||||
return
|
||||
}
|
||||
|
||||
createdUser, err := env.registerUser(data.Username, data.Password)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.Status(r, http.StatusCreated)
|
||||
render.Render(w, r, env.NewUserResponse(createdUser))
|
||||
}
|
||||
|
||||
// Login
|
||||
// @Summary For user login
|
||||
// @Description Description
|
||||
// @Tags User
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param user body UserLoginRequest true "User Login info"
|
||||
// @Success 200 {object} UserResponse
|
||||
// @Failure 400 {object} common.ErrResponse
|
||||
// @Router /api/v1/user/login [post]
|
||||
func (env *Env) loginRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
data := &UserLoginRequest{}
|
||||
err := render.DecodeJSON(r.Body, data)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInvalidRequest(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = validator.New().Struct(data)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrValidationError(err))
|
||||
return
|
||||
}
|
||||
|
||||
loginUser, err := env.checkLogin(data.Username, data.Password)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken, err := env.createSession(loginUser)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
loginCookie := http.Cookie{
|
||||
Name: env.CookieString,
|
||||
Value: sessionToken,
|
||||
MaxAge: 7776000,
|
||||
Path: "/",
|
||||
}
|
||||
http.SetCookie(w, &loginCookie)
|
||||
|
||||
render.Render(w, r, env.NewUserResponse(loginUser))
|
||||
}
|
||||
|
||||
// Logout
|
||||
// @Summary For user logout
|
||||
// @Description Description
|
||||
// @Tags User
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} common.TextResponse
|
||||
// @Failure 400 {object} common.ErrResponse
|
||||
// @Router /api/v1/user/logout [post]
|
||||
func (env *Env) logoutRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(env.CookieString)
|
||||
if err != nil {
|
||||
err = errors.New("user not logged in")
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = env.logout(cookie.Value)
|
||||
if err != nil {
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.Render(w, r, common.NewGenericTextResponse("Ok", "Successfully logged out"))
|
||||
}
|
||||
|
||||
// Check current user
|
||||
// @Summary Returns current logged in user
|
||||
// @Description Description
|
||||
// @Tags User
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} UserResponse
|
||||
// @Failure 400 {object} common.ErrResponse
|
||||
// @Router /api/v1/user/me [get]
|
||||
func (env *Env) meRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
currUser, ok := ctx.Value(UserContextKey).(*User)
|
||||
if !ok {
|
||||
err := errors.New("user not logged in")
|
||||
render.Render(w, r, common.ErrInternalError(err))
|
||||
return
|
||||
}
|
||||
render.Render(w, r, env.NewUserResponse(currUser))
|
||||
}
|
||||
270
backend/internal/user/userroute_test.go
Normal file
270
backend/internal/user/userroute_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
|
||||
)
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
router := UserRoutes(db)
|
||||
testCases := []struct {
|
||||
username string
|
||||
password string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
username: "testusername",
|
||||
password: "testpassword",
|
||||
statusCode: 201,
|
||||
},
|
||||
{
|
||||
username: "",
|
||||
password: "testpassword",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "testusername",
|
||||
password: "",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "t",
|
||||
password: "testpassword",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "testusername",
|
||||
password: "test",
|
||||
statusCode: 422,
|
||||
},
|
||||
}
|
||||
for _, currentTestCase := range testCases {
|
||||
rr := httptest.NewRecorder()
|
||||
currBody := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: currentTestCase.username,
|
||||
Password: currentTestCase.password,
|
||||
}
|
||||
reqBody, err := json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader := bytes.NewReader(reqBody)
|
||||
req := httptest.NewRequest("POST", "/register", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
// Check results
|
||||
if rr.Code != currentTestCase.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
t.Setenv("COOKIE_STRING", "supercustomcookie")
|
||||
router := UserRoutes(db)
|
||||
testCases := []struct {
|
||||
username string
|
||||
password string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
username: "testusername",
|
||||
password: "testpassword",
|
||||
statusCode: 200,
|
||||
},
|
||||
{
|
||||
username: "",
|
||||
password: "testpassword",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "testusername",
|
||||
password: "",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "t",
|
||||
password: "testpassword",
|
||||
statusCode: 422,
|
||||
},
|
||||
{
|
||||
username: "testusername",
|
||||
password: "test",
|
||||
statusCode: 500,
|
||||
},
|
||||
{
|
||||
username: "testusername",
|
||||
password: "",
|
||||
statusCode: 422,
|
||||
},
|
||||
}
|
||||
|
||||
// Register user
|
||||
rr := httptest.NewRecorder()
|
||||
currBody := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
reqBody, err := json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader := bytes.NewReader(reqBody)
|
||||
req := httptest.NewRequest("POST", "/register", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
// Check registration results
|
||||
if rr.Code != 201 {
|
||||
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
|
||||
}
|
||||
for _, currentTestCase := range testCases {
|
||||
|
||||
// Start checking login
|
||||
rrLogin := httptest.NewRecorder()
|
||||
currBody = struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: currentTestCase.username,
|
||||
Password: currentTestCase.password,
|
||||
}
|
||||
reqBody, err = json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader = bytes.NewReader(reqBody)
|
||||
req = httptest.NewRequest("POST", "/login", reqBodyReader)
|
||||
router.ServeHTTP(rrLogin, req)
|
||||
|
||||
// Check login results
|
||||
if rrLogin.Code != currentTestCase.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrLogin.Code)
|
||||
}
|
||||
|
||||
if rrLogin.Code == 200 {
|
||||
if rrLogin.Header().Get("Set-Cookie") == "" {
|
||||
t.Errorf("Expected a cookie to be set, but it wasn't")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMe(t *testing.T) {
|
||||
db := common.TestDBInit()
|
||||
defer common.DestroyTestingDB(db)
|
||||
db.AutoMigrate(&User{})
|
||||
db.AutoMigrate(&Profile{})
|
||||
db.AutoMigrate(&Session{})
|
||||
t.Setenv("ALLOW_REGISTRATION", "true")
|
||||
t.Setenv("COOKIE_STRING", "supercustomcookie")
|
||||
router := UserRoutes(db)
|
||||
testCases := []struct {
|
||||
cookieEnabled bool
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
cookieEnabled: true,
|
||||
statusCode: 200,
|
||||
},
|
||||
{
|
||||
cookieEnabled: false,
|
||||
statusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
// Register user
|
||||
rr := httptest.NewRecorder()
|
||||
currBody := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
reqBody, err := json.Marshal(currBody)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader := bytes.NewReader(reqBody)
|
||||
req := httptest.NewRequest("POST", "/register", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
// Check registration results
|
||||
if rr.Code != 201 {
|
||||
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
|
||||
}
|
||||
|
||||
// Login to get cookie
|
||||
rr = httptest.NewRecorder()
|
||||
currBodyLogin := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "testusername",
|
||||
Password: "testpassword",
|
||||
}
|
||||
reqBody, err = json.Marshal(currBodyLogin)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
reqBodyReader = bytes.NewReader(reqBody)
|
||||
req = httptest.NewRequest("POST", "/login", reqBodyReader)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
for _, currentTestCase := range testCases {
|
||||
// Start checking Profile
|
||||
rrProfile := httptest.NewRecorder()
|
||||
if err != nil {
|
||||
t.Errorf("Error creating a new request body: %v", err)
|
||||
}
|
||||
req = httptest.NewRequest("GET", "/me", nil)
|
||||
if currentTestCase.cookieEnabled {
|
||||
req.AddCookie(rr.Result().Cookies()[0])
|
||||
}
|
||||
router.ServeHTTP(rrProfile, req)
|
||||
|
||||
// Check Profile results
|
||||
if rrProfile.Code != currentTestCase.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
|
||||
}
|
||||
|
||||
if currentTestCase.statusCode == 200 {
|
||||
var resultsObj map[string]any
|
||||
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error reading response body: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resBodyBytes, &resultsObj)
|
||||
if err != nil {
|
||||
t.Errorf("Error unmarshalling response body: %v", err)
|
||||
}
|
||||
|
||||
if resultsObj["username"] != "testusername" {
|
||||
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user