From b9281c89737419216b710a87c31686d21adf86bc Mon Sep 17 00:00:00 2001 From: Nicolas Sterchele Date: Wed, 31 Aug 2022 23:08:52 +0200 Subject: initial commit Thanks to Ted Unangst for its work. Originally available here https://humungus.tedunangst.com/r/inks --- util.go | 281 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100755 util.go (limited to 'util.go') diff --git a/util.go b/util.go new file mode 100755 index 0000000..ec912fe --- /dev/null +++ b/util.go @@ -0,0 +1,281 @@ +// +// Copyright (c) 2019 Ted Unangst +// +// Permission to use, copy, modify, and distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +package main + +/* +#include + +void +termecho(int on) +{ + struct termios t; + tcgetattr(1, &t); + if (on) + t.c_lflag |= ECHO; + else + t.c_lflag &= ~ECHO; + tcsetattr(1, TCSADRAIN, &t); +} +*/ +import "C" + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "database/sql" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/signal" + "strings" + + "golang.org/x/crypto/bcrypt" + _ "humungus.tedunangst.com/r/go-sqlite3" + "humungus.tedunangst.com/r/webs/httpsig" +) + +var savedstyleparams = make(map[string]string) + +func getstyleparam(file string) string { + if p, ok := savedstyleparams[file]; ok { + return p + } + data, err := ioutil.ReadFile(file) + if err != nil { + return "" + } + hasher := sha512.New() + hasher.Write(data) + + return fmt.Sprintf("?v=%.8x", hasher.Sum(nil)) +} + +var dbtimeformat = "2006-01-02 15:04:05" + +var alreadyopendb *sql.DB +var dbname = "inks.db" +var stmtConfig *sql.Stmt + +func initdb() { + schema, err := ioutil.ReadFile("schema.sql") + if err != nil { + log.Fatal(err) + } + _, err = os.Stat(dbname) + if err == nil { + log.Fatalf("%s already exists", dbname) + } + db, err := sql.Open("sqlite3", dbname) + if err != nil { + log.Fatal(err) + } + alreadyopendb = db + defer func() { + os.Remove(dbname) + os.Exit(1) + }() + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt) + go func() { + <-c + C.termecho(1) + fmt.Printf("\n") + os.Remove(dbname) + os.Exit(1) + }() + + for _, line := range strings.Split(string(schema), ";") { + _, err = db.Exec(line) + if err != nil { + log.Print(err) + return + } + } + defer db.Close() + r := bufio.NewReader(os.Stdin) + fmt.Printf("username: ") + name, err := r.ReadString('\n') + if err != nil { + log.Print(err) + return + } + name = name[:len(name)-1] + if len(name) < 1 { + log.Print("that's way too short") + return + } + C.termecho(0) + fmt.Printf("password: ") + pass, err := r.ReadString('\n') + C.termecho(1) + fmt.Printf("\n") + if err != nil { + log.Print(err) + return + } + pass = pass[:len(pass)-1] + if len(pass) < 6 { + log.Print("that's way too short") + return + } + hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12) + if err != nil { + log.Print(err) + return + } + _, err = db.Exec("insert into users (username, hash) values (?, ?)", name, hash) + if err != nil { + log.Print(err) + return + } + fmt.Printf("listen address: ") + addr, err := r.ReadString('\n') + if err != nil { + log.Print(err) + return + } + addr = addr[:len(addr)-1] + if len(addr) < 1 { + log.Print("that's way too short") + return + } + _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr) + if err != nil { + log.Print(err) + return + } + fmt.Printf("server name: ") + addr, err = r.ReadString('\n') + if err != nil { + log.Print(err) + return + } + addr = addr[:len(addr)-1] + if len(addr) < 1 { + log.Print("that's way too short") + return + } + _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr) + if err != nil { + log.Print(err) + return + } + var randbytes [16]byte + rand.Read(randbytes[:]) + key := fmt.Sprintf("%x", randbytes) + _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key) + if err != nil { + log.Print(err) + return + } + setconfig("dbversion", dbVersion) + k, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Print(err) + return + } + pubkey, err := httpsig.EncodeKey(&k.PublicKey) + if err != nil { + log.Print(err) + return + } + _, err = db.Exec("insert into config (key, value) values (?, ?)", "pubkey", pubkey) + if err != nil { + log.Print(err) + return + } + seckey, err := httpsig.EncodeKey(k) + if err != nil { + log.Print(err) + return + } + _, err = db.Exec("insert into config (key, value) values (?, ?)", "seckey", seckey) + if err != nil { + log.Print(err) + return + } + prepareStatements(db) + db.Close() + fmt.Printf("done.\n") + os.Exit(0) +} + +func setconfig(key string, val interface{}) error { + db := opendatabase() + _, err := db.Exec("insert into config (key, value) values (?, ?)", key, val) + return err +} + +func opendatabase() *sql.DB { + if alreadyopendb != nil { + return alreadyopendb + } + var err error + _, err = os.Stat(dbname) + if err != nil { + log.Fatalf("unable to open database: %s", err) + } + db, err := sql.Open("sqlite3", dbname) + if err != nil { + log.Fatalf("unable to open database: %s", err) + } + stmtConfig, err = db.Prepare("select value from config where key = ?") + if err != nil { + log.Fatal(err) + } + alreadyopendb = db + return db +} + +func getconfig(key string, value interface{}) error { + row := stmtConfig.QueryRow(key) + err := row.Scan(value) + if err == sql.ErrNoRows { + err = nil + } + return err +} + +func openListener() (net.Listener, error) { + var listenAddr string + err := getconfig("listenaddr", &listenAddr) + if err != nil { + return nil, err + } + if listenAddr == "" { + return nil, fmt.Errorf("must have listenaddr") + } + proto := "tcp" + if listenAddr[0] == '/' { + proto = "unix" + err := os.Remove(listenAddr) + if err != nil && !os.IsNotExist(err) { + log.Printf("unable to unlink socket: %s", err) + } + } + listener, err := net.Listen(proto, listenAddr) + if err != nil { + return nil, err + } + if proto == "unix" { + os.Chmod(listenAddr, 0777) + } + return listener, nil +} -- cgit v1.2.3