You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

259 lines
5.2 KiB
Go

package main
import (
	"context"
	"crypto/md5"
	"encoding/xml"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/cgi"
	"net/url"
	"strings"
	"time"

	"faculty_media_report/api"
	"faculty_media_report/dbi"
	"faculty_media_report/pages"
)
// CAS (Central Authentication Service) endpoints on the IU identity provider.
// All values are compile-time constants; none of them is meant to change at
// run time.
const (
	// cas_service_url is this application's CAS callback route. It is sent as
	// the "service" parameter to both the login and the ticket-validation
	// endpoints, so it is defined once and reused below.
	cas_service_url = "https://app.law.indiana.edu/faculty/activity/cas"

	// cas_url is where a browser without a ticket is redirected to log in.
	cas_url = "https://idp.login.iu.edu/idp/profile/cas/login?service=" + cas_service_url

	// validation_url_template is a fmt format string. Its single %s verb
	// receives the service ticket, which should be query-escaped by the caller
	// because it originates from the request URL.
	validation_url_template = "https://idp.login.iu.edu/idp/profile/cas/serviceValidate?ticket=%s&service=" + cas_service_url

	// xml_declaration is the XML prolog removed from CAS validation responses
	// before they are unmarshalled.
	xml_declaration = `<?xml version="1.0" encoding="UTF-8"?>`
)
// The types below mirror the part of a CAS serviceValidate response that this
// application reads. The struct tags name only local element names, so
// encoding/xml matches them whatever namespace prefix (e.g. "cas:") the
// identity provider puts on the elements.
type (
	// CASAuthenticationSuccess models the <authenticationSuccess> element;
	// User receives the text content of its <user> child.
	CASAuthenticationSuccess struct {
		XMLName xml.Name `xml:"authenticationSuccess"`
		User    string   `xml:"user"`
	}

	// CASResponse models the root <serviceResponse> element. If the document
	// has no <authenticationSuccess> child (e.g. the ticket was rejected),
	// AuthenticationSuccess.User is left as the empty string.
	CASResponse struct {
		XMLName               xml.Name                 `xml:"serviceResponse"`
		AuthenticationSuccess CASAuthenticationSuccess `xml:"authenticationSuccess"`
	}
)
func parse_username(response_body []byte) string {
var cas_response CASResponse
xmlstring := string(response_body)
xmlstring = strings.Replace(xmlstring, xml_declaration, "", 1)
xmlstring = strings.TrimSpace(xmlstring)
raw_xml_data := []byte(xmlstring)
e := xml.Unmarshal(raw_xml_data, &cas_response)
if e != nil {
return ""
}
return cas_response.AuthenticationSuccess.User
}
// validate_cas asks the CAS server to validate cas_ticket and returns the user
// name reported in its response.
//
// The ticket comes straight from the request's query string. It is therefore
// query-escaped before being substituted into validation_url_template.
// Otherwise a crafted ticket could inject extra parameters (for example a
// second "service") into the validation request, or break URL parsing so that
// attacker-supplied text ends up in the returned error.
//
// Return values:
//   - a non-nil error if the request fails or times out, the server responds
//     with a non-200 status, or the body cannot be read;
//   - ("", nil) if the server answered but its response names no user (as
//     parse_username reports for a rejected ticket or unparsable body);
//   - (username, nil) on success.
func validate_cas(cas_ticket string) (string, error) {
	// Bound the call so a slow or unreachable IdP cannot hang the request.
	const timeout = 15 * time.Second
	client := &http.Client{Timeout: timeout}

	validation_url := fmt.Sprintf(validation_url_template, url.QueryEscape(cas_ticket))

	response, err := client.Get(validation_url)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return "", fmt.Errorf("cas ticket validation returned status %s", response.Status)
	}

	response_body, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	return parse_username(response_body), nil
}
// getCas implements the CAS login callback.
//
// Without a "ticket" query parameter, the browser is redirected to the CAS
// login page (cas_url). With one, the ticket is validated via validate_cas and
// the user name is lower-cased and looked up in the database. A JWT is then
// generated, and a page is rendered according to the user's status:
//   - "admin": the dashboard of reported items;
//   - "faculty": the main form;
//   - anything else is rejected with a plain-text message.
//
// Database, token and page-rendering failures re-render the login page with a
// generic message. The remaining diagnostics are sent as text/plain so that
// text taken from error values (which may echo request data such as the
// ticket) is never interpreted as HTML.
func getCas(w http.ResponseWriter, r *http.Request) {
	fail := func() {
		writeHTML(w, pages.LoginPage("Username or Password is incorrect."))
	}
	writeText := func(msg string) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.Write([]byte(msg))
	}

	cas_ticket := r.URL.Query().Get("ticket")
	if cas_ticket == "" {
		// Not authenticated yet: send the browser to CAS, which is expected
		// to redirect back to this route with a ticket.
		http.Redirect(w, r, cas_url, http.StatusFound)
		return
	}

	username, err := validate_cas(cas_ticket)
	if err != nil {
		writeText(fmt.Sprintf("VALIDATION ERROR: %s\n", err))
		return
	}
	username = strings.ToLower(username)
	if username == "" {
		writeText("CAS ERROR: user not found")
		return
	}

	db, err := dbi.GetDbConn()
	if err != nil {
		fail()
		return
	}
	defer db.Close()

	conn, err := db.Conn(context.Background())
	if err != nil {
		fail()
		return
	}
	defer conn.Close()

	user, err := dbi.GetUser(conn, username)
	if err != nil {
		// Argument order matches the message: user name first, then the error.
		writeText(fmt.Sprintf("User error for user %s: %s", username, err))
		return
	}

	token, err := dbi.GenJWT(user)
	if err != nil {
		fail()
		return
	}

	switch user.Status {
	case "admin":
		items, err := dbi.GetReportedItems(conn)
		if err != nil {
			fail()
			return
		}
		html, err := pages.DashboardPage(token, items)
		if err != nil {
			fail()
			return
		}
		writeHTML(w, html)
	case "faculty":
		writeHTML(w, pages.MainFormPage(token))
	default:
		writeText("User has invalid status")
	}
}
// writeHTML sends html verbatim as the response body, labelled as UTF-8 HTML.
// The Content-Type header is set before the body is written, because the first
// write commits the headers. Write errors are not reported to the caller.
func writeHTML(w http.ResponseWriter, html string) {
	const contentType = "text/html; charset=utf-8"
	w.Header().Set("Content-Type", contentType)
	fmt.Fprintf(w, "%s", html)
}
func handleCas(w http.ResponseWriter, r *http.Request) {
}
func handleLoginGet(w http.ResponseWriter, r *http.Request) {
writeHTML(w, pages.LoginPage(""))
}
// handleLoginPost authenticates the "username" and "password" form fields.
// On success it renders the page for the user's status: the reported-items
// dashboard for "admin", the main form for "faculty". As in the CAS flow
// (getCas), users with any other status are not let in.
//
// Every failure re-renders the login form with the same generic message, so
// the response does not reveal which check failed. Failures include a missing
// field, a database error, an unknown user, a wrong password, an unsupported
// status, and a token or page generation error.
//
// NOTE(review): passwords are compared as unsalted MD5 hex digests, which is
// not a safe password hash. Moving to bcrypt/argon2 requires re-hashing the
// stored values in the database, so the scheme is left unchanged here.
func handleLoginPost(w http.ResponseWriter, r *http.Request) {
	fail := func() {
		writeHTML(w, pages.LoginPage("Username or Password is incorrect."))
	}

	username := r.FormValue("username")
	password := r.FormValue("password")
	if username == "" || password == "" {
		fail()
		return
	}

	db, err := dbi.GetDbConn()
	if err != nil {
		fail()
		return
	}
	defer db.Close()

	conn, err := db.Conn(context.Background())
	if err != nil {
		fail()
		return
	}
	defer conn.Close()

	user, err := dbi.GetUser(conn, username)
	if err != nil {
		fail()
		return
	}

	// Stored passwords are lower-case hex MD5 digests (see NOTE above).
	if fmt.Sprintf("%x", md5.Sum([]byte(password))) != user.Password {
		fail()
		return
	}

	// Reject unknown statuses before issuing a token. Previously every
	// non-admin user fell through to the faculty form.
	if user.Status != "admin" && user.Status != "faculty" {
		fail()
		return
	}

	token, err := dbi.GenJWT(user)
	if err != nil {
		fail()
		return
	}

	if user.Status == "faculty" {
		writeHTML(w, pages.MainFormPage(token))
		return
	}

	items, err := dbi.GetReportedItems(conn)
	if err != nil {
		fail()
		return
	}
	html, err := pages.DashboardPage(token, items)
	if err != nil {
		fail()
		return
	}
	writeHTML(w, html)
}
// getPortal writes pages.LandingPage as the response body. If that write
// fails, it falls back to pages.ErrorPage, which is used as a fmt format string
// filled with a short explanation.
//
// NOTE(review): a failed Write usually means the client connection is already
// unusable, so this fallback write is unlikely to succeed either.
func getPortal(w http.ResponseWriter, r *http.Request) {
	if _, writeErr := w.Write([]byte(pages.LandingPage)); writeErr == nil {
		return
	}
	writeHTML(w, fmt.Sprintf(pages.ErrorPage, "Unable to display landing page"))
}
// main registers the application's routes and serves the current request
// through the CGI interface.
//
// cgi.Serve returns an error if, for example, the process is not running
// under a CGI-capable web server. That error is logged, and the process exits
// with a non-zero status instead of silently succeeding.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("GET /faculty/activity/cas", handleCas)
	mux.HandleFunc("GET /faculty/activity/portal", getPortal)
	mux.HandleFunc("GET /faculty/activity/login", handleLoginGet)
	mux.HandleFunc("POST /faculty/activity/login", handleLoginPost)
	mux.HandleFunc("POST /faculty/activity/api", api.HandleAPI)

	if err := cgi.Serve(mux); err != nil {
		log.Fatalf("serving CGI request: %v", err)
	}
}