Feat(ktm-booking): Initial commit
Some checks failed
ktm-booking-bot/ktm-booking-bot/pipeline/head Something is wrong with the build of this commit

This commit is contained in:
2022-09-27 02:50:07 +08:00
commit 7cf10b07d4
44 changed files with 4569 additions and 0 deletions

View File

@@ -0,0 +1,131 @@
package common
import (
"fmt"
"log"
"os"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
type Database struct {
*gorm.DB
}
var DB *gorm.DB
// Opening a database and save the reference to `Database` struct.
func InitDB() *gorm.DB {
host := os.Getenv("DB_HOST")
user := os.Getenv("DB_USER")
pass := os.Getenv("DB_PASS")
dbName := os.Getenv("DB_NAME")
port := os.Getenv("DB_PORT")
var sslMode string
if os.Getenv("DB_SSL") == "TRUE" {
sslMode = "enable"
} else {
sslMode = "disable"
}
// DB Logger config
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Silent, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Disable color
},
)
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Singapore", host, user, pass, dbName, port, sslMode)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: newLogger,
})
if err != nil {
fmt.Println("db err: (Init) ", err)
}
// Setup UUID
// CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
db.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
DB = db
return DB
}
// This function will create a temporarily database for running testing cases
func TestDBInit() *gorm.DB {
host := os.Getenv("TEST_DB_HOST")
user := os.Getenv("TEST_DB_USER")
pass := os.Getenv("TEST_DB_PASS")
dbName := fmt.Sprintf("%s_%s", os.Getenv("TEST_DB_NAME"), uuid.New().String())
dbName = strings.ReplaceAll(dbName, "-", "_")
port := os.Getenv("TEST_DB_PORT")
var sslMode string
if os.Getenv("TEST_DB_SSL") == "TRUE" {
sslMode = "enable"
} else {
sslMode = "disable"
}
// DB Logger config
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Silent, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Disable color
},
)
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Singapore", host, user, pass, "postgres", port, sslMode)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: newLogger,
})
if err != nil {
fmt.Println("db err: (Init) ", err)
}
// Create Database
err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
if err != nil {
fmt.Println("db err: (Init) ", err)
}
// Get into testing database
dsn = fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Singapore", host, user, pass, dbName, port, sslMode)
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: newLogger,
})
if err != nil {
fmt.Println("db err: (Init) ", err)
}
// Setup UUID
// CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
db.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
DB = db
return DB
}
func DestroyTestingDB(db *gorm.DB) {
var dbName string
db.Raw("SELECT current_database();").Scan(&dbName)
db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName))
}
// Using this function to get a connection, you can create your connection pool here.
func GetDB() *gorm.DB {
return DB
}

View File

@@ -0,0 +1,56 @@
package common
import (
"net/http"
"github.com/go-chi/render"
)
type ErrResponse struct {
Err error `json:"-"` // low-level runtime error
HTTPStatusCode int `json:"-"` // http response status code
StatusText string `json:"status"` // user-level status message
AppCode int64 `json:"code,omitempty"` // application-specific error code
ErrorText string `json:"error,omitempty"` // application-level error message, for debugging
}
func (e *ErrResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}
func ErrInvalidRequest(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 400,
StatusText: "Invalid request.",
ErrorText: err.Error(),
}
}
func ErrValidationError(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 422,
StatusText: "Validation error.",
ErrorText: err.Error(),
}
}
func ErrInternalError(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 500,
StatusText: "Invalid request.",
ErrorText: err.Error(),
}
}
func ErrNotFound(err error) render.Renderer {
return &ErrResponse{
HTTPStatusCode: 404,
StatusText: "Resource not found.",
ErrorText: err.Error(),
}
}

View File

@@ -0,0 +1,24 @@
package common
import (
"net/http"
"github.com/go-chi/render"
)
type TextResponse struct {
Status string `json:"status"` // user-level status message
Text string `json:"text"` // application-specific error code
}
func (e *TextResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, http.StatusOK)
return nil
}
func NewGenericTextResponse(status string, text string) render.Renderer {
return &TextResponse{
Status: status,
Text: text,
}
}

View File

@@ -0,0 +1,477 @@
package ktmtrainbot
import (
"context"
"errors"
"fmt"
"log"
"os"
"strconv"
"strings"
"time"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/user"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/devices"
"github.com/go-rod/rod/lib/launcher"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const IMAGE_DIR = "/tmp/screenshots"
const TIMEOUT_MINUTE = 60
func (env *Env) BackgroundJobRunner() {
log.Println("Initialising background job...")
initialiseRodBrowser()
log.Println("Browser initialised...")
// Initialise silent logger
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
logger.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logger.Silent, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Disable color
},
)
tx := env.DB.Session(&gorm.Session{Logger: newLogger})
for {
var jobToDo Booking
err := tx.Model(&jobToDo).
Where("status = ?", "pending").
Preload("User").
First(&jobToDo).Error
// if no jobs pending found
if err != nil {
time.Sleep(time.Second)
continue
} else { // if there's job to do
// Create next run where it's not the past (either from old NextRun or now())
timeNow := time.Now()
if timeNow.Hour() == 0 && timeNow.Minute() == 10 {
err := env.DB.Where(&user.Profile{UserID: jobToDo.UserID}).First(&jobToDo.User.Profile).Error
if err != nil {
log.Println(err)
}
log.Printf("Start doing job: %v", jobToDo.ID)
username := jobToDo.User.Profile.KtmTrainUsername
password := jobToDo.User.Profile.KtmTrainPassword
creditCardType := jobToDo.User.Profile.KtmTrainCreditCardType
creditCard := jobToDo.User.Profile.KtmTrainCreditCard
creditCardCVV := jobToDo.User.Profile.KtmTrainCreditCardCVV
creditCardExpiry := jobToDo.User.Profile.KtmTrainCreditCardExpiry
func() {
defer func() {
if r := recover(); r != nil {
log.Println("Recovering from job panic...")
jobToDo.Status = "pending"
env.DB.Save(&jobToDo)
}
}()
success := env.startBooking(&jobToDo, username, password, creditCardType, creditCard, creditCardCVV, creditCardExpiry)
if success {
fmt.Println("Successfully made a booking.")
jobToDo.Status = "success"
env.DB.Save(jobToDo)
} else {
jobToDo.Status = "pending"
env.DB.Save(&jobToDo)
fmt.Println("Failed to make a booking.")
}
}()
jobToDo.Status = "running"
env.DB.Save(&jobToDo)
log.Printf("Job Started: %v", jobToDo.ID)
}
}
}
}
func initialiseRodBrowser() {
u := launcher.New().
Set("headless").
MustLaunch()
defaultDevice := devices.LaptopWithMDPIScreen
browser := rod.New().ControlURL(u).MustConnect().DefaultDevice(defaultDevice)
page := browser.MustPage("https://www.google.com").MustWindowFullscreen()
page.MustWaitLoad()
browser.MustClose()
// Initialise screenshot directory
if _, err := os.Stat(IMAGE_DIR); errors.Is(err, os.ErrNotExist) {
err := os.Mkdir(IMAGE_DIR, os.ModePerm)
if err != nil {
log.Println(err)
}
}
}
func (env *Env) startBooking(job *Booking, username string, password string, creditCardType string, creditCard string, creditCardCVV string, creditCardExpiry string) bool {
timerCtx, cancelTimer := context.WithTimeout(context.Background(), TIMEOUT_MINUTE*time.Minute)
defer cancelTimer()
headless := os.Getenv("HEADLESS")
var u string
if strings.ToUpper(headless) == "FALSE" {
u = launcher.New().
Set("headless").
Delete("--headless").
MustLaunch()
} else {
u = launcher.New().
Set("headless").
MustLaunch()
}
defaultDevice := devices.LaptopWithMDPIScreen
// defaultDevice.Screen.Vertical.Height = defaultDevice.Screen.Horizontal.Height
// defaultDevice.Screen.Vertical.Width = defaultDevice.Screen.Horizontal.Width
defaultDevice.UserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36"
browser := rod.New().ControlURL(u).MustConnect().DefaultDevice(defaultDevice)
// Defer closing browser
defer browser.MustClose()
postLoginPage := ktmTrainLogin(browser, username, password)
nowTimeStr := time.Now().Format("2006-01-02-15_04_05")
postLoginPage.MustWaitLoad().MustScreenshot(fmt.Sprintf("%s/%s-01-login.png", IMAGE_DIR, nowTimeStr))
// Exits if context cancelled
select {
case <-timerCtx.Done():
browser.MustClose()
return false
default:
}
var page *rod.Page
getBookingSlotCtx, cancelGetBookingSlot := context.WithTimeout(context.Background(), TIMEOUT_MINUTE*time.Minute)
defer cancelGetBookingSlot()
pageChan := make(chan *rod.Page)
onwardDate := job.TravelDate.Format("2 Jan 2006")
timeCode := job.TimeCode
name := job.Name
gender := job.Gender
passport := job.Passport
passportExpiry := job.PassportExpiry.Format("2 Jan 2006")
contact := job.Contact
threadCount := 10
for i := 0; i < threadCount; i++ {
time.Sleep(time.Millisecond * 100)
go func() {
currPage := getBookingSlots(browser, onwardDate)
log.Println("Booking page loaded.")
currPage = selectBookingSlot(getBookingSlotCtx, currPage, timeCode)
log.Println("Booking slot selected.")
select {
case <-getBookingSlotCtx.Done():
return
default:
log.Println("First page loaded.")
}
// Cancelling context
cancelGetBookingSlot()
pageChan <- currPage
}()
}
page = <-pageChan
page.MustActivate()
// Exits if context cancelled
select {
case <-timerCtx.Done():
browser.MustClose()
return false
default:
}
page = fillPassengerDetails(page, name, gender, passport, passportExpiry, contact)
log.Println("Passenger details filled.")
// Exits if context cancelled
select {
case <-timerCtx.Done():
browser.MustClose()
return false
default:
}
page = choosePayment(page)
log.Println("Payment method chosen.")
// Wait 5 seconds for payment gateway to load
time.Sleep(time.Second * 5)
for _, currPage := range browser.MustPages() {
currPage.MustWaitLoad()
var currTitle string
err := rod.Try(func() {
currTitle = currPage.Timeout(100 * time.Millisecond).MustElement("title").MustText()
})
if err != nil {
currTitle = ""
}
if strings.Contains(currTitle, "Payment Acceptance") {
page = currPage
}
}
// Exits if context cancelled
select {
case <-timerCtx.Done():
browser.MustClose()
return false
default:
}
expiryMonth := strings.Split(creditCardExpiry, "/")[0]
expiryYear := strings.Split(creditCardExpiry, "/")[1]
page = makePayment(page, creditCardType, creditCard, expiryMonth, expiryYear, creditCardCVV)
log.Println("Payment made.")
// // Start debug screenshots
// debugScreenshotCtx, cancelDebugScreenshot := context.WithCancel(context.Background())
// go takeDebugScreenshots(debugScreenshotCtx, courtPage)
// // Defer done with debug screenshot
// defer cancelDebugScreenshot()
time.Sleep(600 * time.Second)
_ = page
browser.MustClose()
return true
}
func ktmTrainLogin(browser *rod.Browser, username string, password string) *rod.Page {
page := browser.MustPage("https://online.ktmb.com.my/Account/Login")
page.MustElement("#Email").MustInput(username)
page.MustElement("#Password").MustInput(password)
page.MustElement("#LoginButton").MustClick()
return page
}
func getBookingSlots(browser *rod.Browser, onwardDate string) *rod.Page {
page := browser.MustPage("https://shuttleonline.ktmb.com.my/Home/Shuttle")
page.MustWaitLoad()
// Dismiss system maintenance warning
bodyText := page.MustElement("body").MustText()
containsCheckStr := "System maintenance scheduled at 23:00 to 00:15"
if strings.Contains(bodyText, containsCheckStr) {
page.MustEval(`() => document.querySelector("#validationSummaryModal > div > div > div.modal-body > div > div.text-center > button").click()`)
}
passengerCount := 1
passengerCountStr := strconv.Itoa(passengerCount)
requestVerificationToken := page.MustElement("#theForm > input[name=__RequestVerificationToken]").MustAttribute("value")
// Get JB Sentral Station Info
jBSentralData := page.MustElement("#FromStationData").MustAttribute("value")
jBSentralID := page.MustElement("#FromStationId").MustAttribute("value")
// Get Woodlands Station Info
woodlandsData := page.MustElement("#ToStationData").MustAttribute("value")
woodlandsID := page.MustElement("#ToStationId").MustAttribute("value")
sensitiveCustomForm(page, *woodlandsData, *jBSentralData, *woodlandsID, *jBSentralID, onwardDate, passengerCountStr, *requestVerificationToken)
page.MustWaitLoad()
return page
}
func selectBookingSlot(ctx context.Context, page *rod.Page, timeCode string) *rod.Page {
time.Sleep(1 * time.Second)
// Initial closing of maintenance modal
bodyText := page.MustElement("body").MustText()
if strings.Contains(bodyText, "System maintenance scheduled at 23:00 to 00:15 (UTC+8)") {
closeModalButton := page.MustElement("#popupModalCloseButton")
closeModalButton.Eval(`this.click()`)
time.Sleep(1000 * time.Millisecond)
}
// Start probing
completed := false
for !completed {
page.MustWaitLoad()
departTripsTable := page.MustElement(".depart-trips")
departRows := departTripsTable.MustElements("tr")
var rowElement *rod.Element
for _, row := range departRows {
timeCodeElement := row.MustAttribute("data-hourminute")
if *timeCodeElement == timeCode {
rowElement = row
}
}
// Checks for context before clicking
select {
case <-ctx.Done():
return page
default:
}
selectButtonElement := rowElement.MustElement("a")
selectButtonElement.Eval(`this.click()`)
time.Sleep(1000 * time.Millisecond)
page.MustWaitLoad()
// Check before exiting
bodyText := page.MustElement("body").MustText()
if strings.Contains(bodyText, "System maintenance scheduled at 23:00 to 00:15 (UTC+8).") {
completed = false
closeModalButton := page.MustElement("#popupModalCloseButton")
closeModalButton.Eval(`this.click()`)
time.Sleep(1000 * time.Millisecond)
} else {
log.Println("Completed probing")
completed = true
}
}
// Checks for context before clicking
select {
case <-ctx.Done():
return page
default:
}
proceedButton := page.MustElement(".proceed-btn")
proceedButton.Eval(`this.click()`)
return page
}
func fillPassengerDetails(page *rod.Page, name string, gender string, passport string, passportExpiry string, contact string) *rod.Page {
ticketType := "DEWASA/ADULT"
nameElement := page.MustElement(".FullName")
nameElement.MustInput(name)
passportElement := page.MustElement(".PassportNo")
passportElement.MustInput(passport)
passportExpiryElement := page.MustElement("#Passengers_0__PassportExpiryDate")
passportExpiryElement.MustInput(passportExpiry)
contactElement := page.MustElement(".ContactNo")
contactElement.MustInput(contact)
if gender == "M" {
maleElement := page.MustElement("#Passengers_0__GenderMale")
maleElement.Eval(`this.click()`)
} else {
femaleElement := page.MustElement("#Passengers_0__GenderFemale")
femaleElement.Eval(`this.click()`)
}
ticketTypeElement := page.MustElement("#Passengers_0__TicketTypeId")
ticketTypeElement.MustSelect(ticketType)
paymentButton := page.MustElement("#btnConfirmPayment")
paymentButton.Eval(`this.click()`)
page.MustWaitLoad()
return page
}
func choosePayment(page *rod.Page) *rod.Page {
creditCardButton := page.MustElement(".btn-public-bank")
creditCardButton.Eval(`this.click()`)
page.MustWaitLoad()
paymentGatewayButton := page.MustElement("#PaymentGateway")
paymentGatewayButton.Eval(`this.click()`)
page.MustWaitLoad()
return page
}
func makePayment(page *rod.Page, cardType string, creditCard string, expiryMonth string, expiryYear string, creditCardCVV string) *rod.Page {
if cardType == "Visa" {
visaRadio := page.MustElement("#card_type_001")
visaRadio.Eval(`this.click()`)
} else if cardType == "Mastercard" {
masterRadio := page.MustElement("#card_type_002")
masterRadio.Eval(`this.click()`)
}
creditCardElement := page.MustElement("#card_number")
creditCardElement.MustInput(creditCard)
expiryMonthElement := page.MustElement("#card_expiry_month")
expiryMonthElement.MustSelect(expiryMonth)
expiryYearElement := page.MustElement("#card_expiry_year")
expiryYearElement.MustSelect(expiryYear)
creditCardCVVElement := page.MustElement("#card_cvn")
creditCardCVVElement.MustInput(creditCardCVV)
log.Println("Before payment")
time.Sleep(10 * time.Millisecond)
payButton := page.MustElement(".pay_button")
payButton.Eval(`this.click()`)
return page
}
func sensitiveCustomForm(
page *rod.Page,
fromStationData string,
toStationData string,
fromStationID string,
toStationID string,
onwardDate string,
passengerCount string,
csrf string,
) *rod.Page {
defer func() {
if r := recover(); r != nil {
log.Println("Recovered in sensitiveCustomForm", r)
}
}()
formHTML := fmt.Sprintf(`
<form action="https://shuttleonline.ktmb.com.my/ShuttleTrip" method="POST">
<input type="hidden" name="FromStationData" value="%s" />
<input type="hidden" name="ToStationData" value="%s" />
<input type="hidden" name="FromStationId" value="%s" />
<input type="hidden" name="ToStationId" value="%s" />
<input type="hidden" name="OnwardDate" value="%s" />
<input type="hidden" name="ReturnDate" value="" />
<input type="hidden" name="PassengerCount" value="%s" />
<input type="hidden" name="__RequestVerificationToken" value="%s" />
<input type="submit" id="presshere" value="Submit request" />
</form>
UniqueStringHere
`, fromStationData, toStationData, fromStationID, toStationID, onwardDate, passengerCount, csrf)
page.MustElement("body").MustEval(fmt.Sprintf("() => this.innerHTML = `%s`", formHTML))
// page.MustElement("#presshere").MustClick()
page.MustEval(`() => document.querySelector("#presshere").click()`)
return page
}

View File

@@ -0,0 +1,82 @@
package ktmtrainbot
import (
"errors"
"log"
"time"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/user"
"github.com/google/uuid"
)
func (env *Env) createBooking(
user *user.User,
travelDate time.Time,
timeCode string,
name string,
gender string,
passport string,
passportExpiry time.Time,
contact string,
) (*Booking, error) {
var newBooking Booking
newBooking.User = *user
newBooking.TravelDate = travelDate
newBooking.TimeCode = timeCode
newBooking.Name = name
newBooking.Gender = gender
newBooking.Passport = passport
newBooking.PassportExpiry = passportExpiry
newBooking.Status = "pending"
err := env.DB.Create(&newBooking).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed create new booking")
}
return &newBooking, nil
}
func (env *Env) getAllBooking(user *user.User) ([]Booking, error) {
var booking []Booking
err := env.DB.Where("user_id = ?", user.ID).Order("court_weekday asc").Find(&booking).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed get booking")
}
return booking, nil
}
func (env *Env) deleteBooking(
user *user.User,
bookingIDStr string,
) (*Booking, error) {
var newBooking Booking
bookingID, err := uuid.Parse(bookingIDStr)
if err != nil {
log.Println(err)
return nil, errors.New("invalid uuid")
}
err = env.DB.Where(&Booking{ID: bookingID}).Where("user_id = ?", user.ID).First(&newBooking).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed retrieve booking")
}
err = env.DB.Delete(&newBooking).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed to delete booking")
}
return &newBooking, nil
}

View File

@@ -0,0 +1,92 @@
package ktmtrainbot
import (
"net/http"
"time"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/user"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Booking struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
User user.User
UserID uuid.UUID
TravelDate time.Time // Only date matters
TimeCode string // e.g. 1400
Name string
Gender string // M/F
Passport string
PassportExpiry time.Time // Only date matters
Contact string // +6512345678
Status string // "success", "error", "pending", "running"
}
type BookingCreateRequest struct {
TravelDate time.Time `json:"travelDate" validate:"required"`
TimeCode string `json:"timeCode" validate:"required len=4"`
Name string `json:"name" validate:"required"`
Gender string `json:"gender" validate:"required len=1 containsany=MF"`
Passport string `json:"passport" validate:"required"`
PassportExpiry time.Time `json:"passportExpiry" validate:"required"`
Contact string `json:"contact" validate:"required e164"`
}
type BookingResponse struct {
ID uuid.UUID `json:"id"`
TravelDate time.Time `json:"travelDate"`
TimeCode string `json:"timeCode"`
Name string `json:"name"`
Gender string `json:"gender"`
Passport string `json:"passport"`
PassportExpiry time.Time `json:"passportExpiry"`
Contact string `json:"contact"`
Status string `json:"status"`
}
type BookingListResponse []BookingResponse
func (res *BookingResponse) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent across the wire
return nil
}
func (res BookingListResponse) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent across the wire
// if res == nil {
// var empty []BookingResponse
// res = empty
// }
return nil
}
func (env *Env) NewBookingResponse(model *Booking) *BookingResponse {
res := &BookingResponse{
ID: model.ID,
TravelDate: model.TravelDate,
TimeCode: model.TimeCode,
Name: model.Name,
Gender: model.Gender,
Passport: model.Passport,
PassportExpiry: model.PassportExpiry,
Contact: model.Contact,
Status: model.Status,
}
return res
}
func (env *Env) NewBookingListResponse(model []Booking) BookingListResponse {
var res []BookingResponse
for _, item := range model {
curr := env.NewBookingResponse(&item)
res = append(res, *curr)
}
return res
}

View File

@@ -0,0 +1,124 @@
package ktmtrainbot
import (
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/user"
"github.com/go-chi/chi"
"github.com/go-chi/render"
"github.com/go-playground/validator/v10"
)
// Get All Bookings
// @Summary Get All Booking
// @Description Description
// @Tags ktmtrainbot Booking
// @Accept json
// @Produce json
// @Success 200 {object} []BookingResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/ktmtrainbot/booking [get]
func (env *Env) getBookingRoute(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currUser, ok := ctx.Value(user.UserContextKey).(*user.User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
_ = currUser
booking, err := env.getAllBooking(currUser)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
render.Render(w, r, env.NewBookingListResponse(booking))
}
// Create New Booking
// @Summary Create New Booking
// @Description Description
// @Tags ktmtrainbot Booking
// @Accept json
// @Produce json
// @Param user body BookingCreateRequest true "Booking Create Request"
// @Success 200 {object} BookingResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/ktmtrainbot/booking [post]
func (env *Env) createBookingRoute(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currUser, ok := ctx.Value(user.UserContextKey).(*user.User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
_ = currUser
data := &BookingCreateRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
err = validator.New().Struct(data)
if err != nil {
render.Render(w, r, common.ErrValidationError(err))
return
}
booking, err := env.createBooking(
currUser,
data.TravelDate,
data.TimeCode,
data.Name,
data.Gender,
data.Passport,
data.PassportExpiry,
data.Contact,
)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
render.Render(w, r, env.NewBookingResponse(booking))
}
// Delete booking
// @Summary Delete booking
// @Description Description
// @Tags ktmtrainbot Booking
// @Accept json
// @Produce json
// @Param bookingID path string true "Booking ID"
// @Success 200 {object} BookingResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/ktmtrainbot/booking/{bookingID} [delete]
func (env *Env) deleteBookingRoute(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currUser, ok := ctx.Value(user.UserContextKey).(*user.User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
_ = currUser
bookingID := chi.URLParam(r, "bookingID")
booking, err := env.deleteBooking(currUser, bookingID)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
render.Render(w, r, env.NewBookingResponse(booking))
}

View File

@@ -0,0 +1,25 @@
package ktmtrainbot
import (
"net/http"
"time"
"github.com/go-chi/render"
)
// Get Current Server Time
// @Summary Get current server time
// @Description Description
// @Tags Info
// @Produce json
// @Success 200 {object} ServerTimeResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/ktmtrainbot/current-time [get]
func (env *Env) getCurrentTime(w http.ResponseWriter, r *http.Request) {
timeNow := time.Now()
var res ServerTimeResponse
res.ServerLocalTime = timeNow.In(time.Local).Format(time.RFC1123Z)
render.Render(w, r, &res)
}

View File

@@ -0,0 +1,31 @@
package ktmtrainbot
import (
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/user"
"github.com/go-chi/chi"
"gorm.io/gorm"
)
type Env struct {
DB *gorm.DB
}
func KTMTrainBotRoutes(db *gorm.DB) chi.Router {
var env Env
env.DB = db
// Start running job
go env.BackgroundJobRunner()
userEnv := user.NewUserEnv(db)
r := chi.NewRouter()
checkLoggedInUserGroup := r.Group(nil)
r.Get("/current-time", env.getCurrentTime)
checkLoggedInUserGroup.Use(userEnv.CheckUserMiddleware)
checkLoggedInUserGroup.Get("/booking", env.getBookingRoute)
checkLoggedInUserGroup.Post("/booking", env.createBookingRoute)
checkLoggedInUserGroup.Delete("/booking/{bookingID}", env.deleteBookingRoute)
return r
}

View File

@@ -0,0 +1,12 @@
package ktmtrainbot
import "net/http"
type ServerTimeResponse struct {
ServerLocalTime string `json:"serverLocalTime"`
}
func (serverTimeResponse *ServerTimeResponse) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent across the wire
return nil
}

View File

@@ -0,0 +1,50 @@
package user
import (
"log"
"os"
"strings"
"github.com/go-chi/chi"
"gorm.io/gorm"
)
type Env struct {
DB *gorm.DB
CookieString string
}
func UserRoutes(db *gorm.DB) chi.Router {
var env Env
env.DB = db
env.CookieString = os.Getenv("COOKIE_STRING")
if env.CookieString == "" {
env.CookieString = "cookie_string"
}
r := chi.NewRouter()
allowRegistration := os.Getenv("ALLOW_REGISTRATION")
if strings.ToUpper(allowRegistration) == "TRUE" || allowRegistration == "1" {
log.Println("Registration enabled.")
r.Post("/register", env.registerRouteHandler)
}
r.Post("/login", env.loginRouteHandler)
r.Post("/logout", env.logoutRouteHandler)
checkLoggedInUserGroup := r.Group(nil)
checkLoggedInUserGroup.Use(env.CheckUserMiddleware)
checkLoggedInUserGroup.Get("/me", env.meRouteHandler)
checkLoggedInUserGroup.Put("/profile", env.setProfileRouteHandler)
return r
}
func NewUserEnv(db *gorm.DB) *Env {
var env Env
env.DB = db
env.CookieString = os.Getenv("COOKIE_STRING")
if env.CookieString == "" {
env.CookieString = "cooking_string"
}
return &env
}

View File

@@ -0,0 +1,146 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestUserRoutesRegistration(t *testing.T) {
testCases := []struct {
allowRegistrationFlag bool
allowRegistrationFlagSet bool
statusCode int
}{
{
allowRegistrationFlag: true,
allowRegistrationFlagSet: true,
statusCode: 201,
},
{
allowRegistrationFlag: false,
allowRegistrationFlagSet: true,
statusCode: 404,
},
{
allowRegistrationFlag: false,
allowRegistrationFlagSet: false,
statusCode: 404,
},
}
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
for _, currTestCase := range testCases {
if currTestCase.allowRegistrationFlagSet {
if currTestCase.allowRegistrationFlag {
t.Setenv("ALLOW_REGISTRATION", "true")
} else {
t.Setenv("ALLOW_REGISTRATION", "false")
}
} else {
t.Setenv("ALLOW_REGISTRATION", "")
}
router := UserRoutes(db)
rr := httptest.NewRecorder()
reqBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBodyBytes, err := json.Marshal(reqBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBodyBytes)
req, err := http.NewRequest("POST", "/register", reqBodyReader)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
currStatusCode := rr.Result().StatusCode
if currStatusCode != currTestCase.statusCode {
t.Errorf("Wrong status code: Expected - %d. Got - %d", currTestCase.statusCode, rr.Result().StatusCode)
}
}
}
func TestUserRoutesLoggedIn(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
router := UserRoutes(db)
rr := httptest.NewRecorder()
// Testing /me
req, err := http.NewRequest("GET", "/me", nil)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
if rr.Result().StatusCode != 500 {
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
}
responseBytes, err := ioutil.ReadAll(rr.Result().Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
var results map[string]any
err = json.Unmarshal(responseBytes, &results)
if err != nil {
t.Errorf("Error decoding response body: %v", err)
}
if results["error"] != "user not logged in" {
t.Errorf("Error for %s not correct", "/me")
}
// Testing /profile
rr = httptest.NewRecorder()
emptyJSON := bytes.NewReader([]byte("{}"))
req, err = http.NewRequest("PUT", "/profile", emptyJSON)
if err != nil {
t.Errorf("Error creating a new request: %v", err)
}
router.ServeHTTP(rr, req)
if rr.Result().StatusCode != 500 {
t.Errorf("Wrong status code: Expected - %d. Got - %d", 500, rr.Result().StatusCode)
}
responseBytes, err = ioutil.ReadAll(rr.Result().Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(responseBytes, &results)
if err != nil {
t.Errorf("Error decoding response body: %v", err)
}
if results["error"] != "user not logged in" {
t.Errorf("Error for %s not correct", "/me")
}
}

View File

@@ -0,0 +1,29 @@
package user
import (
"log"
"gorm.io/gorm/clause"
)
func (env *Env) setProfile(currUser *User, ktmTrainUsername string, ktmTrainPassword string, ktmTrainCreditCardType string, ktmTrainCreditCard string, ktmTrainCreditCardExpiry string, ktmTrainCreditCardCVV string) (*User, error) {
profile := &Profile{
UserID: currUser.ID,
KtmTrainUsername: ktmTrainUsername,
KtmTrainPassword: ktmTrainPassword,
KtmTrainCreditCardType: ktmTrainCreditCardType,
KtmTrainCreditCard: ktmTrainCreditCard,
KtmTrainCreditCardExpiry: ktmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: ktmTrainCreditCardCVV,
}
if err := env.DB.Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(profile).Error; err != nil {
log.Println("Error creating profile", err)
return nil, err
}
currUser.Profile = *profile
return currUser, nil
}

View File

@@ -0,0 +1,52 @@
package user
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Profile struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
UserID uuid.UUID `gorm:"index"`
KtmTrainUsername string
KtmTrainPassword string
KtmTrainCreditCardType string // Visa/Mastercard
KtmTrainCreditCard string
KtmTrainCreditCardExpiry string
KtmTrainCreditCardCVV string
}
type ProfileRequest struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}
type ProfileResponse struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}
func (env *Env) NewProfileResponse(model *Profile) *ProfileResponse {
res := &ProfileResponse{
KtmTrainUsername: model.KtmTrainUsername,
KtmTrainPassword: model.KtmTrainPassword,
KtmTrainCreditCardType: model.KtmTrainCreditCardType,
KtmTrainCreditCard: model.KtmTrainCreditCard,
KtmTrainCreditCardExpiry: model.KtmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: model.KtmTrainCreditCardCVV,
}
return res
}

View File

@@ -0,0 +1,45 @@
package user
import (
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
)
// Set current user profile
// @Summary For setting current user profile
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body ProfileRequest true "User registration info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/profile [put]
func (env *Env) setProfileRouteHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := &ProfileRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
currUser, ok := ctx.Value(UserContextKey).(*User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
currUser, err = env.setProfile(currUser, data.KtmTrainUsername, data.KtmTrainPassword, data.KtmTrainCreditCardType, data.KtmTrainCreditCard, data.KtmTrainCreditCardExpiry, data.KtmTrainCreditCardCVV)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, env.NewUserResponse(currUser))
}

View File

@@ -0,0 +1,174 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestSetProfile(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
ktmTrainUsername string
ktmTrainPassword string
ktmTrainCreditCardType string
ktmTrainCreditCard string
ktmTrainCreditCardExpiry string
ktmTrainCreditCardCVV string
statusCode int
cookieEnabled bool
}{
{
ktmTrainUsername: "test",
ktmTrainPassword: "test",
ktmTrainCreditCardType: "Visa",
ktmTrainCreditCard: "1234567890123456",
ktmTrainCreditCardExpiry: "05/2025",
ktmTrainCreditCardCVV: "123",
statusCode: 200,
cookieEnabled: true,
},
{
ktmTrainUsername: "",
ktmTrainPassword: "",
ktmTrainCreditCardType: "",
ktmTrainCreditCard: "",
ktmTrainCreditCardExpiry: "",
ktmTrainCreditCardCVV: "",
statusCode: 200,
cookieEnabled: true,
},
{
ktmTrainUsername: "test",
ktmTrainPassword: "test",
ktmTrainCreditCardType: "Visa",
ktmTrainCreditCard: "1234567890123456",
ktmTrainCreditCardExpiry: "05/2025",
ktmTrainCreditCardCVV: "123",
statusCode: 500,
cookieEnabled: false,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
// Login to get cookie
rr = httptest.NewRecorder()
currBodyLogin := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err = json.Marshal(currBodyLogin)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rr, req)
for _, currentTestCase := range testCases {
// Start checking Profile
rrProfile := httptest.NewRecorder()
currBody := struct {
KtmTrainUsername string `json:"ktmTrainUsername"`
KtmTrainPassword string `json:"ktmTrainPassword"`
KtmTrainCreditCardType string `json:"ktmTrainCreditCardType"`
KtmTrainCreditCard string `json:"ktmTrainCreditCard"`
KtmTrainCreditCardExpiry string `json:"ktmTrainCreditCardExpiry"`
KtmTrainCreditCardCVV string `json:"ktmTrainCreditCardCVV"`
}{
KtmTrainUsername: currentTestCase.ktmTrainUsername,
KtmTrainPassword: currentTestCase.ktmTrainPassword,
KtmTrainCreditCardType: currentTestCase.ktmTrainCreditCardType,
KtmTrainCreditCard: currentTestCase.ktmTrainCreditCard,
KtmTrainCreditCardExpiry: currentTestCase.ktmTrainCreditCardExpiry,
KtmTrainCreditCardCVV: currentTestCase.ktmTrainCreditCardCVV,
}
reqBody, err = json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("PUT", "/profile", reqBodyReader)
if currentTestCase.cookieEnabled {
req.AddCookie(rr.Result().Cookies()[0])
}
router.ServeHTTP(rrProfile, req)
// Check Profile results
if rrProfile.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
}
if currentTestCase.statusCode == 200 {
var resultsObj map[string]any
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(resBodyBytes, &resultsObj)
if err != nil {
t.Errorf("Error unmarshalling response body: %v", err)
}
resultsProfile := resultsObj["profile"].(map[string]any)
if resultsObj["username"] != "testusername" {
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
}
if resultsProfile["ktmTrainUsername"] != currentTestCase.ktmTrainUsername {
t.Errorf("Expected ktmTrainUsername %s, got %s", currentTestCase.ktmTrainUsername, resultsProfile["ktmTrainUsername"])
}
if resultsProfile["ktmTrainPassword"] != currentTestCase.ktmTrainPassword {
t.Errorf("Expected ktmTrainPassword %s, got %s", currentTestCase.ktmTrainPassword, resultsProfile["ktmTrainPassword"])
}
if resultsProfile["ktmTrainCreditCardType"] != currentTestCase.ktmTrainCreditCardType {
t.Errorf("Expected ktmTrainCreditCardType %s, got %s", currentTestCase.ktmTrainCreditCardType, resultsProfile["ktmTrainCreditCardType"])
}
if resultsProfile["ktmTrainCreditCard"] != currentTestCase.ktmTrainCreditCard {
t.Errorf("Expected ktmTrainCreditCard %s, got %s", currentTestCase.ktmTrainCreditCard, resultsProfile["ktmTrainCreditCard"])
}
if resultsProfile["ktmTrainCreditCardExpiry"] != currentTestCase.ktmTrainCreditCardExpiry {
t.Errorf("Expected ktmTrainCreditCardExpiry %s, got %s", currentTestCase.ktmTrainCreditCardExpiry, resultsProfile["ktmTrainCreditCardExpiry"])
}
if resultsProfile["ktmTrainCreditCardCVV"] != currentTestCase.ktmTrainCreditCardCVV {
t.Errorf("Expected ktmTrainCreditCardCVV %s, got %s", currentTestCase.ktmTrainCreditCardCVV, resultsProfile["ktmTrainCreditCardCVV"])
}
}
}
}

View File

@@ -0,0 +1,56 @@
package user
import (
"errors"
"log"
"github.com/google/uuid"
)
func (env *Env) createSession(user *User) (string, error) {
var newSession Session
newSession.UserID = user.ID
newSession.SessionToken = uuid.New().String()
err := env.DB.Create(&newSession).Error
if err != nil {
log.Println(err)
return "", errors.New("failed write to database")
}
return newSession.SessionToken, nil
}
func (env *Env) getUserFromSessionToken(sessionToken string) (*User, error) {
var currUser User
err := env.DB.Table("sessions").Select("users.*").Joins("left join users on users.id = sessions.user_id").Where("sessions.session_token = ?", sessionToken).First(&currUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed get user")
}
err = env.DB.Preload("Profile").Where(&currUser).First(&currUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed get user")
}
return &currUser, nil
}
func (env *Env) logout(sessionToken string) error {
var currSession Session
err := env.DB.Where(&Session{SessionToken: sessionToken}).First(&currSession).Error
if err != nil {
log.Println(err)
return errors.New("failed get session")
}
err = env.DB.Delete(&currSession).Error
if err != nil {
log.Println(err)
return errors.New("failed to logout")
}
return nil
}

View File

@@ -0,0 +1,29 @@
package user
import (
"context"
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
)
func (env *Env) CheckUserMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(env.CookieString)
if err != nil {
err = errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
user, err := env.getUserFromSessionToken(cookie.Value)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
ctx := context.WithValue(r.Context(), UserContextKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -0,0 +1,7 @@
package user
const (
UserContextKey ContextKey = "user"
)
type ContextKey string

View File

@@ -0,0 +1,17 @@
package user
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type Session struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
SessionToken string
UserID uuid.UUID
}

View File

@@ -0,0 +1,71 @@
package user
import (
"errors"
"log"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
const (
BCRYPTCOST = 12
)
func (env *Env) createUser(username string, password string) (*User, error) {
var createdUser User
passwordByte := []byte(password)
createdHashBytes, err := bcrypt.GenerateFromPassword(passwordByte, BCRYPTCOST)
if err != nil {
return nil, errors.New("failed to generate bcrypt")
}
createdUser.Username = username
createdUser.PasswordBcrypt = string(createdHashBytes)
return &createdUser, nil
}
func (env *Env) registerUser(username string, password string) (*User, error) {
newUser, err := env.createUser(username, password)
if err != nil {
log.Println(err)
return nil, errors.New("failed to register user")
}
// Check existing username
var checkUser User
env.DB.Where(&User{Username: username}).First(&checkUser)
if checkUser.ID != uuid.Nil {
log.Println(err)
return nil, errors.New("user already exists")
}
err = env.DB.Create(newUser).Error
if err != nil {
log.Println(err)
return nil, errors.New("failed write to database")
}
return newUser, nil
}
func (env *Env) checkLogin(username string, password string) (*User, error) {
var currUser User
env.DB.Preload("Profile").Where(&User{Username: username}).First(&currUser)
// Prevent username enum by parsing password
if currUser.ID == uuid.Nil {
bcrypt.GenerateFromPassword([]byte{}, BCRYPTCOST)
return nil, errors.New("invalid username or password")
}
err := bcrypt.CompareHashAndPassword([]byte(currUser.PasswordBcrypt), []byte(password))
if err != nil {
return nil, errors.New("invalid username or password")
} else {
return &currUser, nil
}
}

View File

@@ -0,0 +1,52 @@
package user
import (
"net/http"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type User struct {
ID uuid.UUID `gorm:"primaryKey;type:uuid;default:uuid_generate_v4()"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
Username string
PasswordBcrypt string
Profile Profile
Sessions []Session
}
type UserResponse struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
Profile *ProfileResponse `json:"profile"`
}
type UserRegisterRequest struct {
Username string `json:"username" validate:"required,min=2,max=100"`
Password string `json:"password" validate:"required,min=6,max=100"`
}
type UserLoginRequest struct {
Username string `json:"username" validate:"required,min=2,max=100"`
Password string `json:"password" validate:"required"`
}
func (userResponse *UserResponse) Render(w http.ResponseWriter, r *http.Request) error {
// Pre-processing before a response is marshalled and sent across the wire
return nil
}
func (env *Env) NewUserResponse(user *User) *UserResponse {
profileResponse := env.NewProfileResponse(&user.Profile)
userResponse := &UserResponse{
ID: user.ID,
Username: user.Username,
Profile: profileResponse,
}
return userResponse
}

View File

@@ -0,0 +1,137 @@
package user
import (
"errors"
"net/http"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
"github.com/go-chi/render"
"github.com/go-playground/validator/v10"
)
// User Register
// @Summary For user registration
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body UserRegisterRequest true "User registration info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/register [post]
func (env *Env) registerRouteHandler(w http.ResponseWriter, r *http.Request) {
data := &UserRegisterRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
err = validator.New().Struct(data)
if err != nil {
render.Render(w, r, common.ErrValidationError(err))
return
}
createdUser, err := env.registerUser(data.Username, data.Password)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Status(r, http.StatusCreated)
render.Render(w, r, env.NewUserResponse(createdUser))
}
// Login
// @Summary For user login
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Param user body UserLoginRequest true "User Login info"
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/login [post]
func (env *Env) loginRouteHandler(w http.ResponseWriter, r *http.Request) {
data := &UserLoginRequest{}
err := render.DecodeJSON(r.Body, data)
if err != nil {
render.Render(w, r, common.ErrInvalidRequest(err))
return
}
err = validator.New().Struct(data)
if err != nil {
render.Render(w, r, common.ErrValidationError(err))
return
}
loginUser, err := env.checkLogin(data.Username, data.Password)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
sessionToken, err := env.createSession(loginUser)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
loginCookie := http.Cookie{
Name: env.CookieString,
Value: sessionToken,
MaxAge: 7776000,
Path: "/",
}
http.SetCookie(w, &loginCookie)
render.Render(w, r, env.NewUserResponse(loginUser))
}
// Logout
// @Summary For user logout
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Success 200 {object} common.TextResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/logout [post]
func (env *Env) logoutRouteHandler(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(env.CookieString)
if err != nil {
err = errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
err = env.logout(cookie.Value)
if err != nil {
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, common.NewGenericTextResponse("Ok", "Successfully logged out"))
}
// Check current user
// @Summary Returns current logged in user
// @Description Description
// @Tags User
// @Accept json
// @Produce json
// @Success 200 {object} UserResponse
// @Failure 400 {object} common.ErrResponse
// @Router /api/v1/user/me [get]
func (env *Env) meRouteHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currUser, ok := ctx.Value(UserContextKey).(*User)
if !ok {
err := errors.New("user not logged in")
render.Render(w, r, common.ErrInternalError(err))
return
}
render.Render(w, r, env.NewUserResponse(currUser))
}

View File

@@ -0,0 +1,270 @@
package user
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http/httptest"
"testing"
"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)
func TestRegistration(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
router := UserRoutes(db)
testCases := []struct {
username string
password string
statusCode int
}{
{
username: "testusername",
password: "testpassword",
statusCode: 201,
},
{
username: "",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
{
username: "t",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "test",
statusCode: 422,
},
}
for _, currentTestCase := range testCases {
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: currentTestCase.username,
Password: currentTestCase.password,
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check results
if rr.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rr.Code)
}
}
}
func TestLogin(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
username string
password string
statusCode int
}{
{
username: "testusername",
password: "testpassword",
statusCode: 200,
},
{
username: "",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
{
username: "t",
password: "testpassword",
statusCode: 422,
},
{
username: "testusername",
password: "test",
statusCode: 500,
},
{
username: "testusername",
password: "",
statusCode: 422,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
for _, currentTestCase := range testCases {
// Start checking login
rrLogin := httptest.NewRecorder()
currBody = struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: currentTestCase.username,
Password: currentTestCase.password,
}
reqBody, err = json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rrLogin, req)
// Check login results
if rrLogin.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrLogin.Code)
}
if rrLogin.Code == 200 {
if rrLogin.Header().Get("Set-Cookie") == "" {
t.Errorf("Expected a cookie to be set, but it wasn't")
}
}
}
}
func TestGetMe(t *testing.T) {
db := common.TestDBInit()
defer common.DestroyTestingDB(db)
db.AutoMigrate(&User{})
db.AutoMigrate(&Profile{})
db.AutoMigrate(&Session{})
t.Setenv("ALLOW_REGISTRATION", "true")
t.Setenv("COOKIE_STRING", "supercustomcookie")
router := UserRoutes(db)
testCases := []struct {
cookieEnabled bool
statusCode int
}{
{
cookieEnabled: true,
statusCode: 200,
},
{
cookieEnabled: false,
statusCode: 500,
},
}
// Register user
rr := httptest.NewRecorder()
currBody := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err := json.Marshal(currBody)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader := bytes.NewReader(reqBody)
req := httptest.NewRequest("POST", "/register", reqBodyReader)
router.ServeHTTP(rr, req)
// Check registration results
if rr.Code != 201 {
t.Errorf("Expected status code %d, got %d", 201, rr.Code)
}
// Login to get cookie
rr = httptest.NewRecorder()
currBodyLogin := struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: "testusername",
Password: "testpassword",
}
reqBody, err = json.Marshal(currBodyLogin)
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
reqBodyReader = bytes.NewReader(reqBody)
req = httptest.NewRequest("POST", "/login", reqBodyReader)
router.ServeHTTP(rr, req)
for _, currentTestCase := range testCases {
// Start checking Profile
rrProfile := httptest.NewRecorder()
if err != nil {
t.Errorf("Error creating a new request body: %v", err)
}
req = httptest.NewRequest("GET", "/me", nil)
if currentTestCase.cookieEnabled {
req.AddCookie(rr.Result().Cookies()[0])
}
router.ServeHTTP(rrProfile, req)
// Check Profile results
if rrProfile.Code != currentTestCase.statusCode {
t.Errorf("Expected status code %d, got %d", currentTestCase.statusCode, rrProfile.Code)
}
if currentTestCase.statusCode == 200 {
var resultsObj map[string]any
resBodyBytes, err := ioutil.ReadAll(rrProfile.Body)
if err != nil {
t.Errorf("Error reading response body: %v", err)
}
err = json.Unmarshal(resBodyBytes, &resultsObj)
if err != nil {
t.Errorf("Error unmarshalling response body: %v", err)
}
if resultsObj["username"] != "testusername" {
t.Errorf("Expected username %s, got %s", "testusername", resultsObj["username"])
}
}
}
}