package user

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"testing"

	"git.samuelpua.com/telboon/ktm-train-bot/backend/internal/common"
)

// serveUserRoute builds a request for method/target carrying body (which may be
// nil), runs it through h and returns the recorded response. A request that
// cannot be constructed aborts the calling test, since nothing meaningful can be
// checked afterwards.
func serveUserRoute(t *testing.T, h http.Handler, method, target string, body io.Reader) *http.Response {
	t.Helper()

	req, err := http.NewRequest(method, target, body)
	if err != nil {
		t.Fatalf("Error creating a new %s %s request: %v", method, target, err)
	}

	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	return rr.Result()
}

// decodeJSONObject reads res.Body, closes it, and decodes it into a freshly
// allocated map. A new map is used on every call on purpose: json.Unmarshal
// keeps existing entries when given a non-nil map, which would let keys from an
// earlier response mask a missing key in a later one.
func decodeJSONObject(t *testing.T, res *http.Response) map[string]any {
	t.Helper()
	defer res.Body.Close()

	raw, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("Error reading response body: %v", err)
	}

	var obj map[string]any
	if err := json.Unmarshal(raw, &obj); err != nil {
		t.Fatalf("Error decoding response body %q: %v", raw, err)
	}
	return obj
}

// TestUserRoutesRegistration checks that POST /register answers 201 when
// ALLOW_REGISTRATION is "true", and 404 when it is "false" or not set at all.
//
// All cases share one database and run in table order. The first case registers
// "testusername". Later cases reuse the same credentials and assert only the
// disabled-route status, not anything about the user being new.
func TestUserRoutesRegistration(t *testing.T) {
	testCases := []struct {
		name                     string
		allowRegistrationFlag    bool
		allowRegistrationFlagSet bool
		statusCode               int
	}{
		{name: "enabled", allowRegistrationFlag: true, allowRegistrationFlagSet: true, statusCode: http.StatusCreated},
		{name: "disabled", allowRegistrationFlag: false, allowRegistrationFlagSet: true, statusCode: http.StatusNotFound},
		{name: "unset", allowRegistrationFlag: false, allowRegistrationFlagSet: false, statusCode: http.StatusNotFound},
	}

	db := common.TestDBInit()
	defer common.DestroyTestingDB(db)
	db.AutoMigrate(&User{})
	db.AutoMigrate(&Profile{})
	db.AutoMigrate(&Session{})

	// The payload is identical for every case; marshal it once and hand each
	// request its own reader over the bytes.
	reqBodyBytes, err := json.Marshal(struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}{
		Username: "testusername",
		Password: "testpassword",
	})
	if err != nil {
		t.Fatalf("Error creating a new request body: %v", err)
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.allowRegistrationFlagSet {
				t.Setenv("ALLOW_REGISTRATION", strconv.FormatBool(tc.allowRegistrationFlag))
			} else {
				// t.Setenv records the current value and restores it when the
				// subtest ends. Unsetenv then makes the variable truly absent,
				// which is what this case claims to test; an empty value is
				// not the same thing for os.LookupEnv.
				t.Setenv("ALLOW_REGISTRATION", "")
				if err := os.Unsetenv("ALLOW_REGISTRATION"); err != nil {
					t.Fatalf("Error unsetting ALLOW_REGISTRATION: %v", err)
				}
			}

			// Build the router after the environment is prepared for this case.
			router := UserRoutes(db)

			res := serveUserRoute(t, router, http.MethodPost, "/register", bytes.NewReader(reqBodyBytes))
			defer res.Body.Close()

			if res.StatusCode != tc.statusCode {
				t.Errorf("Wrong status code: Expected - %d. Got - %d", tc.statusCode, res.StatusCode)
			}
		})
	}
}

// TestUserRoutesLoggedIn checks that session-only routes reject an anonymous
// request. Each endpoint must answer with status 500 and the JSON body field
// "error" equal to "user not logged in".
//
// NOTE(review): this pins the handlers' current choice of 500 rather than 401
// for a missing session; confirm that is the intended contract.
func TestUserRoutesLoggedIn(t *testing.T) {
	db := common.TestDBInit()
	defer common.DestroyTestingDB(db)
	db.AutoMigrate(&User{})
	db.AutoMigrate(&Profile{})
	db.AutoMigrate(&Session{})

	t.Setenv("ALLOW_REGISTRATION", "true")
	router := UserRoutes(db)

	endpoints := []struct {
		name   string
		method string
		path   string
		body   string // empty means the request is sent with a nil body
	}{
		{name: "me", method: http.MethodGet, path: "/me"},
		{name: "profile", method: http.MethodPut, path: "/profile", body: "{}"},
	}

	for _, ep := range endpoints {
		t.Run(ep.name, func(t *testing.T) {
			// Keep a genuinely nil io.Reader for body-less requests, rather than
			// a non-nil reader over zero bytes.
			var body io.Reader
			if ep.body != "" {
				body = bytes.NewReader([]byte(ep.body))
			}

			res := serveUserRoute(t, router, ep.method, ep.path, body)
			if res.StatusCode != http.StatusInternalServerError {
				t.Errorf("Wrong status code: Expected - %d. Got - %d", http.StatusInternalServerError, res.StatusCode)
			}

			results := decodeJSONObject(t, res)
			if results["error"] != "user not logged in" {
				t.Errorf("Error for %s not correct: got %v", ep.path, results["error"])
			}
		})
	}
}