diff --git a/klice.go b/klice.go index 7c81813..ba1d7ec 100644 --- a/klice.go +++ b/klice.go @@ -48,8 +48,14 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { } http.SetCookie(w, cookie) - http.Redirect(w, r, "/", http.StatusSeeOther) - return + redir, err := r.Cookie("url") + if err == nil { + redir.MaxAge = -1 + http.SetCookie(w, redir) + http.Redirect(w, r, redir.Value, http.StatusSeeOther) + } else { + http.Redirect(w, r, "/team", http.StatusSeeOther) + } } else if r.Method == http.MethodGet { loginPage, err := os.Open("login.html") if err != nil { @@ -59,41 +65,67 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { defer loginPage.Close() io.Copy(w, loginPage) - return } } func logoutHandler(w http.ResponseWriter, r *http.Request) { cookie := &http.Cookie{ - Name: "session_id", - Value: "", - Path: "/", - MaxAge: -1, + Name: "session_id", + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, } http.SetCookie(w, cookie) http.Redirect(w, r, "/login", http.StatusSeeOther) } -func isLoggedIn(r *http.Request) (bool, int) { +func isLoggedIn(w http.ResponseWriter, r *http.Request) (bool, int) { + var exist bool = true + var teamID int cookie, err := r.Cookie("session_id") if err != nil { - return false, 0 + exist = false + } else { + err = db.QueryRow("SELECT id FROM teams WHERE password = ?", cookie.Value).Scan(&teamID) + if err == sql.ErrNoRows { + exist = false + } else if err != nil { + exist = false + } } - var teamID int - err = db.QueryRow("SELECT id FROM teams WHERE password = ?", cookie.Value).Scan(&teamID) - if err == sql.ErrNoRows { - return false, 0 - } else if err != nil { + if !exist { + redir := &http.Cookie{ + Name: "url", + Value: r.URL.String(), + MaxAge: 300, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + } + http.SetCookie(w, redir) + http.Redirect(w, r, "/login", http.StatusSeeOther) return false, 0 } - return true, teamID } +func teamInfoHandler(w http.ResponseWriter, r *http.Request) { + if loggedIn, teamID := isLoggedIn(w, r); loggedIn { + var teamName, city string + err := db.QueryRow("SELECT name, city FROM teams WHERE id = ?", teamID).Scan(&teamName, &city) + if err != nil { + http.Error(w, "Could not retrieve team information", http.StatusInternalServerError) + return + } + fmt.Fprintf(w, "Team Name: %s, City: %s", teamName, city) + } +} + func main() { - // Set up a new SQLite database var err error db, err = sql.Open("sqlite3", "./klice.db?_fk=on") if err != nil { @@ -104,19 +136,7 @@ func main() { http.HandleFunc("/login", loginHandler) http.HandleFunc("/logout", logoutHandler) - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if loggedIn, teamID := isLoggedIn(r); loggedIn { - var teamName string - err := db.QueryRow("SELECT name FROM teams WHERE id = ?", teamID).Scan(&teamName) - if err != nil { - http.Error(w, "Could not retrieve team name", http.StatusInternalServerError) - return - } - fmt.Fprintf(w, "Welcome back, team %s!", teamName) - } else { - http.Redirect(w, r, "/login", http.StatusSeeOther) - } - }) + http.HandleFunc("/team", teamInfoHandler) fmt.Println("Server started at :8080") http.ListenAndServe(":8080", nil)