Use cookies for session storage

Remove the server side session storage and store all the session related data
in the client side cookies. This decreases the exposure of the auth tokens.
It also simplifies the installation process as bloat no longer requires write
access to the filesystem.

This is a breaking change, all the existing sessions will stop working.
This commit is contained in:
r 2022-10-25 13:40:49 +00:00
parent b4ccde54a7
commit 887ed241d6
14 changed files with 225 additions and 495 deletions

10
INSTALL
View File

@ -23,16 +23,8 @@ most cases, you only need to change the value of "client_website".
# cp bloat.gen.conf /etc/bloat.conf
# $EDITOR /etc/bloat.conf
4. Create database directory
Create a directory to store session information. Optionally, create a user
to run bloat and change the ownership of the database directory accordingly.
# mkdir /var/bloat
# useradd _bloat
# chown -R _bloat:_bloat /var/bloat
Replace /var/bloat with the value you specified in the config file.
5. Run the binary
# su _bloat -c bloat
$ bloat
Now you should create an init script to automatically start bloat at system
startup.

View File

@ -10,7 +10,6 @@ SRC=main.go \
mastodon/*.go \
model/*.go \
renderer/*.go \
repo/*.go \
service/*.go \
util/*.go \
@ -18,8 +17,7 @@ all: bloat
bloat: $(SRC) $(TMPL)
$(GO) build $(GOFLAGS) -o bloat main.go
sed -e "s%=database%=/var/bloat%g" \
-e "s%=templates%=$(SHAREPATH)/templates%g" \
sed -e "s%=templates%=$(SHAREPATH)/templates%g" \
-e "s%=static%=$(SHAREPATH)/static%g" \
< bloat.conf > bloat.gen.conf

View File

@ -3,10 +3,6 @@
# - Key and Value are separated by a single '='
# - Leading and trailing white spaces in Key and Value are ignored
# - Quoting and multi-line values are not supported
#
# Changing values of client_name, client_scope or client_website will cause
# previously generated access tokens and client tokens to be invalid. Issue the
# `rm -r database_path/*` command to clean the database afterwards.
# Address to listen to. Value can be of "HOSTNAME:PORT" or "IP:PORT" form. In
# case of empty HOSTNAME or IP, "0.0.0.0:PORT" is used.
@ -25,9 +21,6 @@ client_name=bloat
# See https://docs.joinmastodon.org/api/oauth-scopes/
client_scope=read write follow
# Path of database directory. It's used to store session information.
database_path=database
# Path of directory containing template files.
templates_path=templates

View File

@ -18,7 +18,6 @@ type config struct {
SingleInstance string
StaticDirectory string
TemplatesPath string
DatabasePath string
CustomCSS string
PostFormats []model.PostFormat
LogFile string
@ -30,8 +29,7 @@ func (c *config) IsValid() bool {
len(c.ClientScope) < 1 ||
len(c.ClientWebsite) < 1 ||
len(c.StaticDirectory) < 1 ||
len(c.TemplatesPath) < 1 ||
len(c.DatabasePath) < 1 {
len(c.TemplatesPath) < 1 {
return false
}
return true
@ -75,10 +73,10 @@ func Parse(r io.Reader) (c *config, err error) {
c.StaticDirectory = val
case "templates_path":
c.TemplatesPath = val
case "database_path":
c.DatabasePath = val
case "custom_css":
c.CustomCSS = val
case "database_path":
// ignore
case "post_formats":
vals := strings.Split(val, ",")
var formats []model.PostFormat

24
main.go
View File

@ -12,9 +12,7 @@ import (
"bloat/config"
"bloat/renderer"
"bloat/repo"
"bloat/service"
"bloat/util"
)
var (
@ -48,26 +46,6 @@ func main() {
errExit(err)
}
err = os.Mkdir(config.DatabasePath, 0755)
if err != nil && !os.IsExist(err) {
errExit(err)
}
sessionDBPath := filepath.Join(config.DatabasePath, "session")
sessionDB, err := util.NewDatabse(sessionDBPath)
if err != nil {
errExit(err)
}
appDBPath := filepath.Join(config.DatabasePath, "app")
appDB, err := util.NewDatabse(appDBPath)
if err != nil {
errExit(err)
}
sessionRepo := repo.NewSessionRepo(sessionDB)
appRepo := repo.NewAppRepo(appDB)
customCSS := config.CustomCSS
if len(customCSS) > 0 && !strings.HasPrefix(customCSS, "http://") &&
!strings.HasPrefix(customCSS, "https://") {
@ -89,7 +67,7 @@ func main() {
s := service.NewService(config.ClientName, config.ClientScope,
config.ClientWebsite, customCSS, config.SingleInstance,
config.PostFormats, renderer, sessionRepo, appRepo)
config.PostFormats, renderer)
handler := service.NewHandler(s, logger, config.StaticDirectory)
logger.Println("listening on", config.ListenAddress)

View File

@ -1,21 +0,0 @@
package model
import (
"errors"
)
var (
ErrAppNotFound = errors.New("app not found")
)
type App struct {
InstanceDomain string `json:"instance_domain"`
InstanceURL string `json:"instance_url"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
type AppRepo interface {
Add(app App) (err error)
Get(instanceDomain string) (app App, err error)
}

View File

@ -1,28 +1,48 @@
package model
import (
"errors"
)
var (
ErrSessionNotFound = errors.New("session not found")
)
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
InstanceDomain string `json:"instance_domain"`
AccessToken string `json:"access_token"`
CSRFToken string `json:"csrf_token"`
Settings Settings `json:"settings"`
}
type SessionRepo interface {
Add(session Session) (err error)
Get(sessionID string) (session Session, err error)
Remove(sessionID string)
ID string `json:"id,omitempty"`
UserID string `json:"uid,omitempty"`
Instance string `json:"ins,omitempty"`
ClientID string `json:"cid,omitempty"`
ClientSecret string `json:"cs,omitempty"`
AccessToken string `json:"at,omitempty"`
CSRFToken string `json:"csrf,omitempty"`
Settings Settings `json:"sett,omitempty"`
}
func (s Session) IsLoggedIn() bool {
return len(s.AccessToken) > 0
}
type Settings struct {
DefaultVisibility string `json:"dv,omitempty"`
DefaultFormat string `json:"df,omitempty"`
CopyScope bool `json:"cs,omitempty"`
ThreadInNewTab bool `json:"tnt,omitempty"`
HideAttachments bool `json:"ha,omitempty"`
MaskNSFW bool `json:"mn,omitempty"`
NotificationInterval int `json:"ni,omitempty"`
FluorideMode bool `json:"fm,omitempty"`
DarkMode bool `json:"dm,omitempty"`
AntiDopamineMode bool `json:"adm,omitempty"`
HideUnsupportedNotifs bool `json:"hun,omitempty"`
CSS string `json:"css,omitempty"`
}
func NewSettings() *Settings {
return &Settings{
DefaultVisibility: "public",
DefaultFormat: "",
CopyScope: true,
ThreadInNewTab: false,
HideAttachments: false,
MaskNSFW: true,
NotificationInterval: 0,
FluorideMode: false,
DarkMode: false,
AntiDopamineMode: false,
HideUnsupportedNotifs: false,
CSS: "",
}
}

View File

@ -1,33 +0,0 @@
package model
type Settings struct {
DefaultVisibility string `json:"default_visibility"`
DefaultFormat string `json:"default_format"`
CopyScope bool `json:"copy_scope"`
ThreadInNewTab bool `json:"thread_in_new_tab"`
HideAttachments bool `json:"hide_attachments"`
MaskNSFW bool `json:"mask_nfsw"`
NotificationInterval int `json:"notifications_interval"`
FluorideMode bool `json:"fluoride_mode"`
DarkMode bool `json:"dark_mode"`
AntiDopamineMode bool `json:"anti_dopamine_mode"`
HideUnsupportedNotifs bool `json:"hide_unsupported_notifs"`
CSS string `json:"css"`
}
func NewSettings() *Settings {
return &Settings{
DefaultVisibility: "public",
DefaultFormat: "",
CopyScope: true,
ThreadInNewTab: false,
HideAttachments: false,
MaskNSFW: true,
NotificationInterval: 0,
FluorideMode: false,
DarkMode: false,
AntiDopamineMode: false,
HideUnsupportedNotifs: false,
CSS: "",
}
}

View File

@ -1,42 +0,0 @@
package repo
import (
"encoding/json"
"bloat/util"
"bloat/model"
)
type appRepo struct {
db *util.Database
}
func NewAppRepo(db *util.Database) *appRepo {
return &appRepo{
db: db,
}
}
func (repo *appRepo) Add(a model.App) (err error) {
data, err := json.Marshal(a)
if err != nil {
return
}
err = repo.db.Set(a.InstanceDomain, data)
return
}
func (repo *appRepo) Get(instanceDomain string) (a model.App, err error) {
data, err := repo.db.Get(instanceDomain)
if err != nil {
err = model.ErrAppNotFound
return
}
err = json.Unmarshal(data, &a)
if err != nil {
return
}
return
}

View File

@ -1,47 +0,0 @@
package repo
import (
"encoding/json"
"bloat/util"
"bloat/model"
)
type sessionRepo struct {
db *util.Database
}
func NewSessionRepo(db *util.Database) *sessionRepo {
return &sessionRepo{
db: db,
}
}
func (repo *sessionRepo) Add(s model.Session) (err error) {
data, err := json.Marshal(s)
if err != nil {
return
}
err = repo.db.Set(s.ID, data)
return
}
func (repo *sessionRepo) Get(id string) (s model.Session, err error) {
data, err := repo.db.Get(id)
if err != nil {
err = model.ErrSessionNotFound
return
}
err = json.Unmarshal(data, &s)
if err != nil {
return
}
return
}
func (repo *sessionRepo) Remove(id string) {
repo.db.Remove(id)
return
}

111
service/client.go Normal file
View File

@ -0,0 +1,111 @@
package service
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
"time"
"bloat/mastodon"
"bloat/model"
"bloat/renderer"
)
type client struct {
*mastodon.Client
w http.ResponseWriter
r *http.Request
s *model.Session
csrf string
ctx context.Context
rctx *renderer.Context
}
func (c *client) setSession(sess *model.Session) error {
var sb strings.Builder
bw := base64.NewEncoder(base64.URLEncoding, &sb)
err := json.NewEncoder(bw).Encode(sess)
bw.Close()
if err != nil {
return err
}
http.SetCookie(c.w, &http.Cookie{
Name: "session",
Value: sb.String(),
Expires: time.Now().Add(365 * 24 * time.Hour),
})
return nil
}
func (c *client) getSession() (sess *model.Session, err error) {
cookie, _ := c.r.Cookie("session")
if cookie == nil {
return nil, errInvalidSession
}
br := base64.NewDecoder(base64.URLEncoding, strings.NewReader(cookie.Value))
err = json.NewDecoder(br).Decode(&sess)
return
}
func (c *client) unsetSession() {
http.SetCookie(c.w, &http.Cookie{
Name: "session",
Value: "",
Expires: time.Now(),
})
}
func (c *client) writeJson(data interface{}) error {
return json.NewEncoder(c.w).Encode(map[string]interface{}{
"data": data,
})
}
func (c *client) redirect(url string) {
c.w.Header().Add("Location", url)
c.w.WriteHeader(http.StatusFound)
}
func (c *client) authenticate(t int) (err error) {
csrf := c.r.FormValue("csrf_token")
ref := c.r.URL.RequestURI()
defer func() {
if c.s == nil {
c.s = &model.Session{
Settings: *model.NewSettings(),
}
}
c.rctx = &renderer.Context{
HideAttachments: c.s.Settings.HideAttachments,
MaskNSFW: c.s.Settings.MaskNSFW,
ThreadInNewTab: c.s.Settings.ThreadInNewTab,
FluorideMode: c.s.Settings.FluorideMode,
DarkMode: c.s.Settings.DarkMode,
CSRFToken: c.s.CSRFToken,
UserID: c.s.UserID,
AntiDopamineMode: c.s.Settings.AntiDopamineMode,
UserCSS: c.s.Settings.CSS,
Referrer: ref,
}
}()
if t < SESSION {
return
}
sess, err := c.getSession()
if err != nil {
return err
}
c.s = sess
c.Client = mastodon.NewClient(&mastodon.Config{
Server: "https://" + c.s.Instance,
ClientID: c.s.ClientID,
ClientSecret: c.s.ClientSecret,
AccessToken: c.s.AccessToken,
})
if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) {
return errInvalidCSRFToken
}
return
}

View File

@ -27,14 +27,11 @@ type service struct {
instance string
postFormats []model.PostFormat
renderer renderer.Renderer
sessionRepo model.SessionRepo
appRepo model.AppRepo
}
func NewService(cname string, cscope string, cwebsite string,
css string, instance string, postFormats []model.PostFormat,
renderer renderer.Renderer, sessionRepo model.SessionRepo,
appRepo model.AppRepo) *service {
renderer renderer.Renderer) *service {
return &service{
cname: cname,
cscope: cscope,
@ -43,57 +40,9 @@ func NewService(cname string, cscope string, cwebsite string,
instance: instance,
postFormats: postFormats,
renderer: renderer,
sessionRepo: sessionRepo,
appRepo: appRepo,
}
}
func (s *service) authenticate(c *client, sid string, csrf string, ref string, t int) (err error) {
var sett *model.Settings
defer func() {
if sett == nil {
sett = model.NewSettings()
}
c.rctx = &renderer.Context{
HideAttachments: sett.HideAttachments,
MaskNSFW: sett.MaskNSFW,
ThreadInNewTab: sett.ThreadInNewTab,
FluorideMode: sett.FluorideMode,
DarkMode: sett.DarkMode,
CSRFToken: c.s.CSRFToken,
UserID: c.s.UserID,
AntiDopamineMode: sett.AntiDopamineMode,
UserCSS: sett.CSS,
Referrer: ref,
}
}()
if t < SESSION {
return
}
if len(sid) < 1 {
return errInvalidSession
}
c.s, err = s.sessionRepo.Get(sid)
if err != nil {
return errInvalidSession
}
sett = &c.s.Settings
app, err := s.appRepo.Get(c.s.InstanceDomain)
if err != nil {
return err
}
c.Client = mastodon.NewClient(&mastodon.Config{
Server: app.InstanceURL,
ClientID: app.ClientID,
ClientSecret: app.ClientSecret,
AccessToken: c.s.AccessToken,
})
if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) {
return errInvalidCSRFToken
}
return
}
func (s *service) cdata(c *client, title string, count int, rinterval int,
target string) (data *renderer.CommonData) {
data = &renderer.CommonData{
@ -820,7 +769,7 @@ func (s *service) SingleInstance() (instance string, ok bool) {
return
}
func (s *service) NewSession(c *client, instance string) (rurl string, sid string, err error) {
func (s *service) NewSession(c *client, instance string) (rurl string, sess *model.Session, err error) {
var instanceURL string
if strings.HasPrefix(instance, "https://") {
instanceURL = instance
@ -829,7 +778,7 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sid strin
instanceURL = "https://" + instance
}
sid, err = util.NewSessionID()
sid, err := util.NewSessionID()
if err != nil {
return
}
@ -838,42 +787,23 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sid strin
return
}
sess := model.Session{
ID: sid,
InstanceDomain: instance,
CSRFToken: csrf,
Settings: *model.NewSettings(),
}
err = s.sessionRepo.Add(sess)
app, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{
Server: instanceURL,
ClientName: s.cname,
Scopes: s.cscope,
Website: s.cwebsite,
RedirectURIs: s.cwebsite + "/oauth_callback",
})
if err != nil {
return
}
app, err := s.appRepo.Get(instance)
if err != nil {
if err != model.ErrAppNotFound {
return
}
mastoApp, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{
Server: instanceURL,
ClientName: s.cname,
Scopes: s.cscope,
Website: s.cwebsite,
RedirectURIs: s.cwebsite + "/oauth_callback",
})
if err != nil {
return "", "", err
}
app = model.App{
InstanceDomain: instance,
InstanceURL: instanceURL,
ClientID: mastoApp.ClientID,
ClientSecret: mastoApp.ClientSecret,
}
err = s.appRepo.Add(app)
if err != nil {
return "", "", err
}
sess = &model.Session{
ID: sid,
Instance: instance,
ClientID: app.ClientID,
ClientSecret: app.ClientSecret,
CSRFToken: csrf,
Settings: *model.NewSettings(),
}
u, err := url.Parse("/oauth/authorize")
@ -907,12 +837,7 @@ func (s *service) Signin(c *client, code string) (err error) {
}
c.s.AccessToken = c.GetAccessToken(c.ctx)
c.s.UserID = u.ID
return s.sessionRepo.Add(c.s)
}
func (s *service) Signout(c *client) (err error) {
s.sessionRepo.Remove(c.s.ID)
return
return c.setSession(c.s)
}
func (s *service) Post(c *client, content string, replyToID string,
@ -1044,12 +969,8 @@ func (s *service) SaveSettings(c *client, settings *model.Settings) (err error)
if len(settings.CSS) > 1<<20 {
return errInvalidArgument
}
sess, err := s.sessionRepo.Get(c.s.ID)
if err != nil {
return
}
sess.Settings = *settings
return s.sessionRepo.Add(sess)
c.s.Settings = *settings
return c.setSession(c.s)
}
func (s *service) MuteConversation(c *client, id string) (err error) {

View File

@ -1,24 +1,17 @@
package service
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
"time"
"bloat/mastodon"
"bloat/model"
"bloat/renderer"
"github.com/gorilla/mux"
)
const (
sessionExp = 365 * 24 * time.Hour
)
const (
HTML int = iota
JSON
@ -30,35 +23,6 @@ const (
CSRF
)
type client struct {
*mastodon.Client
w http.ResponseWriter
r *http.Request
s model.Session
csrf string
ctx context.Context
rctx *renderer.Context
}
func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sid,
Expires: time.Now().Add(exp),
})
}
func writeJson(c *client, data interface{}) error {
return json.NewEncoder(c.w).Encode(map[string]interface{}{
"data": data,
})
}
func redirect(c *client, url string) {
c.w.Header().Add("Location", url)
c.w.WriteHeader(http.StatusFound)
}
func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
r := mux.NewRouter()
@ -75,16 +39,6 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
}
}
authenticate := func(c *client, t int) error {
var sid string
if cookie, _ := c.r.Cookie("session_id"); cookie != nil {
sid = cookie.Value
}
csrf := c.r.FormValue("csrf_token")
ref := c.r.URL.RequestURI()
return s.authenticate(c, sid, csrf, ref, t)
}
handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
var err error
@ -108,7 +62,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
}
c.w.Header().Add("Content-Type", ct)
err = authenticate(c, at)
err = c.authenticate(at)
if err != nil {
writeError(c, err, rt, req.Method == http.MethodGet)
return
@ -123,16 +77,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
}
rootPage := handle(func(c *client) error {
err := authenticate(c, SESSION)
err := c.authenticate(SESSION)
if err != nil {
if err == errInvalidSession {
redirect(c, "/signin")
c.redirect("/signin")
return nil
}
return err
}
if !c.s.IsLoggedIn() {
redirect(c, "/signin")
c.redirect("/signin")
return nil
}
return s.RootPage(c)
@ -147,12 +101,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if !ok {
return s.SigninPage(c)
}
url, sid, err := s.NewSession(c, instance)
url, sess, err := s.NewSession(c, instance)
if err != nil {
return err
}
setSessionCookie(c.w, sid, sessionExp)
redirect(c, url)
c.setSession(sess)
c.redirect(url)
return nil
}, NOAUTH, HTML)
@ -167,7 +121,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
}, SESSION, HTML)
defaultTimelinePage := handle(func(c *client) error {
redirect(c, "/timeline/home")
c.redirect("/timeline/home")
return nil
}, SESSION, HTML)
@ -243,12 +197,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
signin := handle(func(c *client) error {
instance := c.r.FormValue("instance")
url, sid, err := s.NewSession(c, instance)
url, sess, err := s.NewSession(c, instance)
if err != nil {
return err
}
setSessionCookie(c.w, sid, sessionExp)
redirect(c, url)
c.setSession(sess)
c.redirect(url)
return nil
}, NOAUTH, HTML)
@ -259,7 +213,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, "/")
c.redirect("/")
return nil
}, SESSION, HTML)
@ -287,7 +241,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
} else {
location = c.r.FormValue("referrer")
}
redirect(c, location)
c.redirect(location)
return nil
}, CSRF, HTML)
@ -301,7 +255,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -315,7 +269,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -329,7 +283,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -343,7 +297,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -355,7 +309,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer")+"#status-"+statusID)
c.redirect(c.r.FormValue("referrer") + "#status-" + statusID)
return nil
}, CSRF, HTML)
@ -371,7 +325,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -381,7 +335,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -391,7 +345,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -401,7 +355,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -417,7 +371,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -427,7 +381,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -437,7 +391,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -447,7 +401,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -457,7 +411,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -467,7 +421,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -504,7 +458,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, "/")
c.redirect("/")
return nil
}, CSRF, HTML)
@ -514,7 +468,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -524,7 +478,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -534,7 +488,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -545,7 +499,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -559,7 +513,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -573,7 +527,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if len(rid) > 0 {
id = rid
}
redirect(c, c.r.FormValue("referrer")+"#status-"+id)
c.redirect(c.r.FormValue("referrer") + "#status-" + id)
return nil
}, CSRF, HTML)
@ -584,7 +538,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -594,7 +548,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -608,7 +562,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -618,7 +572,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -629,7 +583,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -648,7 +602,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
@ -660,14 +614,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
redirect(c, c.r.FormValue("referrer"))
c.redirect(c.r.FormValue("referrer"))
return nil
}, CSRF, HTML)
signout := handle(func(c *client) error {
s.Signout(c)
setSessionCookie(c.w, "", 0)
redirect(c, "/")
c.unsetSession()
c.redirect("/")
return nil
}, CSRF, HTML)
@ -677,7 +630,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
return writeJson(c, count)
return c.writeJson(count)
}, CSRF, JSON)
fUnlike := handle(func(c *client) error {
@ -686,7 +639,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
return writeJson(c, count)
return c.writeJson(count)
}, CSRF, JSON)
fRetweet := handle(func(c *client) error {
@ -695,7 +648,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
return writeJson(c, count)
return c.writeJson(count)
}, CSRF, JSON)
fUnretweet := handle(func(c *client) error {
@ -704,7 +657,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler {
if err != nil {
return err
}
return writeJson(c, count)
return c.writeJson(count)
}, CSRF, JSON)
r.HandleFunc("/", rootPage).Methods(http.MethodGet)

View File

@ -1,91 +0,0 @@
package util
import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
)
var (
errInvalidKey = errors.New("invalid key")
errNoSuchKey = errors.New("no such key")
)
type Database struct {
cache map[string][]byte
basedir string
m sync.RWMutex
}
func NewDatabse(basedir string) (db *Database, err error) {
err = os.Mkdir(basedir, 0755)
if err != nil && !os.IsExist(err) {
return
}
return &Database{
cache: make(map[string][]byte),
basedir: basedir,
}, nil
}
func (db *Database) Set(key string, val []byte) (err error) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return errInvalidKey
}
err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644)
if err != nil {
return
}
db.m.Lock()
db.cache[key] = val
db.m.Unlock()
return
}
func (db *Database) Get(key string) (val []byte, err error) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return nil, errInvalidKey
}
db.m.RLock()
data, ok := db.cache[key]
db.m.RUnlock()
if !ok {
data, err = ioutil.ReadFile(filepath.Join(db.basedir, key))
if err != nil {
err = errNoSuchKey
return nil, err
}
db.m.Lock()
db.cache[key] = data
db.m.Unlock()
}
val = make([]byte, len(data))
copy(val, data)
return
}
func (db *Database) Remove(key string) {
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
return
}
os.Remove(filepath.Join(db.basedir, key))
db.m.Lock()
delete(db.cache, key)
db.m.Unlock()
return
}