Feat(ktm-booking): Initial commit
Some checks failed
ktm-booking-bot/ktm-booking-bot/pipeline/head Something is wrong with the build of this commit

This commit is contained in:
2022-09-27 02:50:07 +08:00
commit 7cf10b07d4
44 changed files with 4569 additions and 0 deletions

View File

@@ -0,0 +1,50 @@
package user
import (
"log"
"os"
"strings"
"github.com/go-chi/chi"
"gorm.io/gorm"
)
type Env struct {
DB *gorm.DB
CookieString string
}
func UserRoutes(db *gorm.DB) chi.Router {
var env Env
env.DB = db
env.CookieString = os.Getenv("COOKIE_STRING")
if env.CookieString == "" {
env.CookieString = "cookie_string"
}
r := chi.NewRouter()
allowRegistration := os.Getenv("ALLOW_REGISTRATION")
if strings.ToUpper(allowRegistration) == "TRUE" || allowRegistration == "1" {
log.Println("Registration enabled.")
r.Post("/register", env.registerRouteHandler)
}
r.Post("/login", env.loginRouteHandler)
r.Post("/logout", env.logoutRouteHandler)
checkLoggedInUserGroup := r.Group(nil)
checkLoggedInUserGroup.Use(env.CheckUserMiddleware)
checkLoggedInUserGroup.Get("/me", env.meRouteHandler)
checkLoggedInUserGroup.Put("/profile", env.setProfileRouteHandler)
return r
}
func NewUserEnv(db *gorm.DB) *Env {
var env Env
env.DB = db
env.CookieString = os.Getenv("COOKIE_STRING")
if env.CookieString == "" {
env.CookieString = "cooking_string"
}
return &env
}

View File

@@ -0,0 +1,146 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestUserRoutesRegistration(t *testing.T) {
testCases := []struct {
allowRegistrationFlag bool
allowRegistrationFlagSet bool
statusCode int
}{
{
allowRegistrationFlag: true,
allowRegistrationFlagSet: true,
statusCode: 201,
},
{
allowRegistrationFlag: false,
allowRegistrationFlagSet: true,
statusCode: 404,
},
{
allowRegistrationFlag: false,
allowRegistrationFlagSet: false,
statusCode: 404,
},
}
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
for _, currTestCase := range testCases {
if currTestCase.allowRegistrationFlagSet {
if currTestCase.allowRegistrationFlag {
t.Setenv("ALLOW_REGISTRATION", "true")
} else {
t.Setenv("ALLOW_REGISTRATION", "false")
}
} else {
t.Setenv("ALLOW_REGISTRATION", "")
}
router := UserRoutes(db)
rr := httptest.NewRecorder()
reqBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBodyBytes)
req, err := http.NewRequest("POST", "/register", reqBodyReader)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
currStatusCode := rr.Result().StatusCode
if currStatusCode != currTestCase.statusCode {
t.Errorf("Wrong status code: Expected - %d. Got - %d", currTestCase.statusCode, rr.Result().StatusCode)
}
}
}
func TestUserRoutesLoggedIn(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
router := UserRoutes(db)
rr := httptest.NewRecorder()
// Testing /me
req, err := http.NewRequest("GET", "/me", nil)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
if rr.Result().StatusCode != 500 {
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
}
responseBytes, err := ioutil.ReadAll(rr.Result().Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
var results map[string]any
err = json.Unmarshal(responseBytes, &results)
if err != nil {
t.Errorf("Error decoding response body: %v", err)
}
if results["error"] != "user not logged in" {
t.Errorf("Error for %s not correct", "/me")
}
// Testing /profile
rr = httptest.NewRecorder()
emptyJSON := bytes.NewReader([]byte("{}"))
req, err = http.NewRequest("PUT", "/profile", emptyJSON)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
if rr.Result().StatusCode != 500 {
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
}
responseBytes, err = ioutil.ReadAll(rr.Result().Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(responseBytes, &results)
if err != nil {
t.Errorf("Error decoding response body: %v", err)
}
if results["error"] != "user not logged in" {
t.Errorf("Error for %s not correct", "/me")
}
}

View File

@@ -0,0 +1,29 @@
package user
import (
"log"
"gorm.io/gorm/clause"
)
func (env *Env) setProfile(currUser *User, ktmTrainUsername string, ktmTrainPassword string, ktmTrainCreditCardType string, ktmTrainCreditCard string, ktmTrainCreditCardExpiry string, ktmTrainCreditCardCVV string) (*User, error) {
profile := &Profile{
UserID: currUser.ID,
KtmTrainUsername: ktmTrainUsername,
KtmTrainPassword: ktmTrainPassword,
KtmTrainCreditCardType: ktmTrainCreditCardType,
KtmTrainCreditCard: ktmTrainCreditCard,
KtmTrainCreditCardExpiry: ktmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: ktmTrainCreditCardCVV,
}
if err := env.DB.Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(profile).Error; err != nil {
log.Println("Error creating profile", err)
return nil, err
}
currUser.Profile = *profile
return currUser, nil
}

View File

@@ -0,0 +1,52 @@
package user
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Profile struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
UserID uuid.UUID `gorm:"index"`
KtmTrainUsername string
KtmTrainPassword string
KtmTrainCreditCardType string // Visa/Mastercard
KtmTrainCreditCard string
KtmTrainCreditCardExpiry string
KtmTrainCreditCardCVV string
}
type ProfileRequest struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}
type ProfileResponse struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}
func (env *Env) NewProfileResponse(model *Profile) *ProfileResponse {
res := &ProfileResponse{
KtmTrainUsername: model.KtmTrainUsername,
KtmTrainPassword: model.KtmTrainPassword,
KtmTrainCreditCardType: model.KtmTrainCreditCardType,
KtmTrainCreditCard: model.KtmTrainCreditCard,
KtmTrainCreditCardExpiry: model.KtmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: model.KtmTrainCreditCardCVV,
}
return res
}

View File

@@ -0,0 +1,45 @@
package user
import (
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
)
// Set current user profile
// @Summary For setting current user profile
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body ProfileRequest true "User registration info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/profile [put]
func (env *Env) setProfileRouteHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := &ProfileRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
currUser, ok := ctx.Value(UserContextKey).(*User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
currUser, err = env.setProfile(currUser, data.KtmTrainUsername, data.KtmTrainPassword, data.KtmTrainCreditCardType, data.KtmTrainCreditCard, data.KtmTrainCreditCardExpiry, data.KtmTrainCreditCardCVV)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, env.NewUserResponse(currUser))
}

View File

@@ -0,0 +1,174 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestSetProfile(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
ktmTrainUsername string
ktmTrainPassword string
ktmTrainCreditCardType string
ktmTrainCreditCard string
ktmTrainCreditCardExpiry string
ktmTrainCreditCardCVV string
statusCode int
cookieEnabled bool
}{
{
ktmTrainUsername: "test",
ktmTrainPassword: "test",
ktmTrainCreditCardType: "Visa",
ktmTrainCreditCard: "1234567890123456",
ktmTrainCreditCardExpiry: "05/2025",
ktmTrainCreditCardCVV: "123",
statusCode: 200,
cookieEnabled: true,
},
{
ktmTrainUsername: "",
ktmTrainPassword: "",
ktmTrainCreditCardType: "",
ktmTrainCreditCard: "",
ktmTrainCreditCardExpiry: "",
ktmTrainCreditCardCVV: "",
statusCode: 200,
cookieEnabled: true,
},
{
ktmTrainUsername: "test",
ktmTrainPassword: "test",
ktmTrainCreditCardType: "Visa",
ktmTrainCreditCard: "1234567890123456",
ktmTrainCreditCardExpiry: "05/2025",
ktmTrainCreditCardCVV: "123",
statusCode: 500,
cookieEnabled: false,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
// Login to get cookie
rr = httptest.NewRecorder()
currBodyLogin := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err = json.Marshal(currBodyLogin)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rr, req)
for _, currentTestCase := range testCases {
// Start checking Profile
rrProfile := httptest.NewRecorder()
currBody := struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}{
KtmTrainUsername: currentTestCase.ktmTrainUsername,
KtmTrainPassword: currentTestCase.ktmTrainPassword,
KtmTrainCreditCardType: currentTestCase.ktmTrainCreditCardType,
KtmTrainCreditCard: currentTestCase.ktmTrainCreditCard,
KtmTrainCreditCardExpiry: currentTestCase.ktmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: currentTestCase.ktmTrainCreditCardCVV,
}
reqBody, err = json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("PUT", "/profile", reqBodyReader)
if currentTestCase.cookieEnabled {
req.AddCookie(rr.Result().Cookies()[0])
}
router.ServeHTTP(rrProfile, req)
// Check Profile results
if rrProfile.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
}
if currentTestCase.statusCode == 200 {
var resultsObj map[string]any
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(resBodyBytes, &resultsObj)
if err != nil {
t.Errorf("Error unmarshalling response body: %v", err)
}
resultsProfile := resultsObj["profile"].(map[string]any)
if resultsObj["username"] != "testusername" {
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
}
if resultsProfile["ktmTrainUsername"] != currentTestCase.ktmTrainUsername {
t.Errorf("Expected ktmTrainUsername %s, got %s", currentTestCase.ktmTrainUsername, resultsProfile["ktmTrainUsername"])
}
if resultsProfile["ktmTrainPassword"] != currentTestCase.ktmTrainPassword {
t.Errorf("Expected ktmTrainPassword %s, got %s", currentTestCase.ktmTrainPassword, resultsProfile["ktmTrainPassword"])
}
if resultsProfile["ktmTrainCreditCardType"] != currentTestCase.ktmTrainCreditCardType {
t.Errorf("Expected ktmTrainCreditCardType %s, got %s", currentTestCase.ktmTrainCreditCardType, resultsProfile["ktmTrainCreditCardType"])
}
if resultsProfile["ktmTrainCreditCard"] != currentTestCase.ktmTrainCreditCard {
t.Errorf("Expected ktmTrainCreditCard %s, got %s", currentTestCase.ktmTrainCreditCard, resultsProfile["ktmTrainCreditCard"])
}
if resultsProfile["ktmTrainCreditCardExpiry"] != currentTestCase.ktmTrainCreditCardExpiry {
t.Errorf("Expected ktmTrainCreditCardExpiry %s, got %s", currentTestCase.ktmTrainCreditCardExpiry, resultsProfile["ktmTrainCreditCardExpiry"])
}
if resultsProfile["ktmTrainCreditCardCVV"] != currentTestCase.ktmTrainCreditCardCVV {
t.Errorf("Expected ktmTrainCreditCardCVV %s, got %s", currentTestCase.ktmTrainCreditCardCVV, resultsProfile["ktmTrainCreditCardCVV"])
}
}
}
}

View File

@@ -0,0 +1,56 @@
package user
import (
"errors"
"log"
"github.com/google/uuid"
)
func (env *Env) createSession(user *User) (string, error) {
var newSession Session
newSession.UserID = user.ID
newSession.SessionToken = uuid.New().String()
err := env.DB.Create(&newSession).Error
if err != nil {
log.Println(err)
return "", errors.New("failed write to database")
}
return newSession.SessionToken, nil
}
func (env *Env) getUserFromSessionToken(sessionToken string) (*User, error) {
var currUser User
err := env.DB.Table("sessions").Select("users.*").Joins("left join users on users.id = sessions.user_id").Where("sessions.session_token = ?", sessionToken).First(&currUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed get user")
}
err = env.DB.Preload("Profile").Where(&currUser).First(&currUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed get user")
}
return &currUser, nil
}
func (env *Env) logout(sessionToken string) error {
var currSession Session
err := env.DB.Where(&Session{SessionToken: sessionToken}).First(&currSession).Error
if err != nil {
log.Println(err)
return errors.New("failed get session")
}
err = env.DB.Delete(&currSession).Error
if err != nil {
log.Println(err)
return errors.New("failed to logout")
}
return nil
}

View File

@@ -0,0 +1,29 @@
package user
import (
"context"
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
)
func (env *Env) CheckUserMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(env.CookieString)
if err != nil {
err = errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
user, err := env.getUserFromSessionToken(cookie.Value)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
ctx := context.WithValue(r.Context(), UserContextKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -0,0 +1,7 @@
package user
const (
UserContextKey ContextKey = "user"
)
type ContextKey string

View File

@@ -0,0 +1,17 @@
package user
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Session struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
SessionToken string
UserID uuid.UUID
}

View File

@@ -0,0 +1,71 @@
package user
import (
"errors"
"log"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
const (
BCRYPTCOST = 12
)
func (env *Env) createUser(username string, password string) (*User, error) {
var createdUser User
passwordByte := []byte(password)
createdHashBytes, err := bcrypt.GenerateFromPassword(passwordByte, BCRYPTCOST)
if err != nil {
return nil, errors.New("failed to generate bcrypt")
}
createdUser.Username = username
createdUser.PasswordBcrypt = string(createdHashBytes)
return &createdUser, nil
}
func (env *Env) registerUser(username string, password string) (*User, error) {
newUser, err := env.createUser(username, password)
if err != nil {
log.Println(err)
return nil, errors.New("failed to register user")
}
// Check existing username
var checkUser User
env.DB.Where(&User{Username: username}).First(&checkUser)
if checkUser.ID != uuid.Nil {
log.Println(err)
return nil, errors.New("user already exists")
}
err = env.DB.Create(newUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed write to database")
}
return newUser, nil
}
func (env *Env) checkLogin(username string, password string) (*User, error) {
var currUser User
env.DB.Preload("Profile").Where(&User{Username: username}).First(&currUser)
// Prevent username enum by parsing password
if currUser.ID == uuid.Nil {
bcrypt.GenerateFromPassword([]byte{}, BCRYPTCOST)
return nil, errors.New("invalid username or password")
}
err := bcrypt.CompareHashAndPassword([]byte(currUser.PasswordBcrypt), []byte(password))
if err != nil {
return nil, errors.New("invalid username or password")
} else {
return &currUser, nil
}
}

View File

@@ -0,0 +1,52 @@
package user
import (
"net/http"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type User struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
Username string
PasswordBcrypt string
Profile Profile
Sessions []Session
}
type UserResponse struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
Profile *ProfileResponse `json:"profile"`
}
type UserRegisterRequest struct {
Username string `json:"username" validate:"required,min=2,max=100"`
Password string `json:"password" validate:"required,min=6,max=100"`
}
type UserLoginRequest struct {
Username string `json:"username" validate:"required,min=2,max=100"`
Password string `json:"password" validate:"required"`
}
func (userResponse *UserResponse) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent across the wire
return nil
}
func (env *Env) NewUserResponse(user *User) *UserResponse {
profileResponse := env.NewProfileResponse(&user.Profile)
userResponse := &UserResponse{
ID: user.ID,
Username: user.Username,
Profile: profileResponse,
}
return userResponse
}

View File

@@ -0,0 +1,137 @@
package user
import (
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
"github.com/go-playground/validator/v10"
)
// User Register
// @Summary For user registration
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body UserRegisterRequest true "User registration info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/register [post]
func (env *Env) registerRouteHandler(w http.ResponseWriter, r *http.Request) {
data := &UserRegisterRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
err = validator.New().Struct(data)
if err != nil {
render.Render(w, r, common.ErrValidationError(err))
return
}
createdUser, err := env.registerUser(data.Username, data.Password)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Status(r, http.StatusCreated)
render.Render(w, r, env.NewUserResponse(createdUser))
}
// Login
// @Summary For user login
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body UserLoginRequest true "User Login info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/login [post]
func (env *Env) loginRouteHandler(w http.ResponseWriter, r *http.Request) {
data := &UserLoginRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
err = validator.New().Struct(data)
if err != nil {
render.Render(w, r, common.ErrValidationError(err))
return
}
loginUser, err := env.checkLogin(data.Username, data.Password)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
sessionToken, err := env.createSession(loginUser)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
loginCookie := http.Cookie{
Name: env.CookieString,
Value: sessionToken,
MaxAge: 7776000,
Path: "/",
}
http.SetCookie(w, &loginCookie)
render.Render(w, r, env.NewUserResponse(loginUser))
}
// Logout
// @Summary For user logout
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Success 200 {object} common.TextResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/logout [post]
func (env *Env) logoutRouteHandler(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(env.CookieString)
if err != nil {
err = errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
err = env.logout(cookie.Value)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, common.NewGenericTextResponse("Ok", "Successfully logged out"))
}
// Check current user
// @Summary Returns current logged in user
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/me [get]
func (env *Env) meRouteHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currUser, ok := ctx.Value(UserContextKey).(*User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, env.NewUserResponse(currUser))
}

View File

@@ -0,0 +1,270 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestRegistration(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
router := UserRoutes(db)
testCases := []struct {
username string
password string
statusCode int
}{
{
username: "testusername",
password: "testpassword",
statusCode: 201,
},
{
username: "",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
{
username: "t",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "test",
statusCode: 422,
},
}
for _, currentTestCase := range testCases {
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: currentTestCase.username,
Password: currentTestCase.password,
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check results
if rr.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rr.Code)
}
}
}
func TestLogin(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
username string
password string
statusCode int
}{
{
username: "testusername",
password: "testpassword",
statusCode: 200,
},
{
username: "",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
{
username: "t",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "test",
statusCode: 500,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
for _, currentTestCase := range testCases {
// Start checking login
rrLogin := httptest.NewRecorder()
currBody = struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: currentTestCase.username,
Password: currentTestCase.password,
}
reqBody, err = json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rrLogin, req)
// Check login results
if rrLogin.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrLogin.Code)
}
if rrLogin.Code == 200 {
if rrLogin.Header().Get("Set-Cookie") == "" {
t.Errorf("Expected a cookie to be set, but it wasn't")
}
}
}
}
func TestGetMe(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
cookieEnabled bool
statusCode int
}{
{
cookieEnabled: true,
statusCode: 200,
},
{
cookieEnabled: false,
statusCode: 500,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
// Login to get cookie
rr = httptest.NewRecorder()
currBodyLogin := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err = json.Marshal(currBodyLogin)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rr, req)
for _, currentTestCase := range testCases {
// Start checking Profile
rrProfile := httptest.NewRecorder()
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
req = httptest.NewRequest("GET", "/me", nil)
if currentTestCase.cookieEnabled {
req.AddCookie(rr.Result().Cookies()[0])
}
router.ServeHTTP(rrProfile, req)
// Check Profile results
if rrProfile.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
}
if currentTestCase.statusCode == 200 {
var resultsObj map[string]any
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(resBodyBytes, &resultsObj)
if err != nil {
t.Errorf("Error unmarshalling response body: %v", err)
}
if resultsObj["username"] != "testusername" {
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
}
}
}
}