Use filesystem based kv store instead of sqlite
This commit is contained in:
parent
3b50f40c08
commit
59aad78f66
2
go.mod
2
go.mod
|
@ -4,7 +4,7 @@ go 1.13
|
|||
|
||||
require (
|
||||
github.com/gorilla/mux v1.7.3
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible
|
||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible // indirect
|
||||
mastodon v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
module web
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/gorilla/mux v1.7.3
|
||||
mastodon v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
replace mastodon => ./mastodon
|
4
go.sum
4
go.sum
|
@ -2,7 +2,7 @@ github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
|
|||
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
|
||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
|
||||
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
package kv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidKey = errors.New("invalid key")
|
||||
errNoSuchKey = errors.New("no such key")
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
data map[string][]byte
|
||||
basedir string
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDatabse(basedir string) (db *Database, err error) {
|
||||
err = os.Mkdir(basedir, 0755)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return
|
||||
}
|
||||
|
||||
return &Database{
|
||||
data: make(map[string][]byte),
|
||||
basedir: basedir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *Database) Set(key string, val []byte) (err error) {
|
||||
if len(key) < 1 {
|
||||
return errInvalidKey
|
||||
}
|
||||
|
||||
db.m.Lock()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
delete(db.data, key)
|
||||
}
|
||||
db.m.Unlock()
|
||||
}()
|
||||
|
||||
db.data[key] = val
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) Get(key string) (val []byte, err error) {
|
||||
if len(key) < 1 {
|
||||
return nil, errInvalidKey
|
||||
}
|
||||
|
||||
db.m.RLock()
|
||||
defer db.m.RUnlock()
|
||||
|
||||
data, ok := db.data[key]
|
||||
if !ok {
|
||||
data, err = ioutil.ReadFile(filepath.Join(db.basedir, key))
|
||||
if err != nil {
|
||||
err = errNoSuchKey
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.data[key] = data
|
||||
}
|
||||
|
||||
val = make([]byte, len(data))
|
||||
copy(val, data)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) Remove(key string) {
|
||||
if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
|
||||
return
|
||||
}
|
||||
|
||||
db.m.Lock()
|
||||
defer db.m.Unlock()
|
||||
|
||||
delete(db.data, key)
|
||||
os.Remove(filepath.Join(db.basedir, key))
|
||||
|
||||
return
|
||||
}
|
17
main.go
17
main.go
|
@ -1,19 +1,18 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"web/config"
|
||||
"web/kv"
|
||||
"web/renderer"
|
||||
"web/repository"
|
||||
"web/service"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -35,22 +34,24 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", config.DatabasePath)
|
||||
if err != nil {
|
||||
err = os.Mkdir(config.DatabasePath, 0755)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
sessionRepo, err := repository.NewSessionRepository(db)
|
||||
sessionDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "session"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
appRepo, err := repository.NewAppRepository(db)
|
||||
appDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "app"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
sessionRepo := repository.NewSessionRepository(sessionDB)
|
||||
appRepo := repository.NewAppRepository(appDB)
|
||||
|
||||
var logger *log.Logger
|
||||
if len(config.Logfile) < 1 {
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||
|
|
33
model/app.go
33
model/app.go
|
@ -1,19 +1,40 @@
|
|||
package model
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAppNotFound = errors.New("app not found")
|
||||
)
|
||||
|
||||
type App struct {
|
||||
InstanceURL string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
InstanceDomain string
|
||||
InstanceURL string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
type AppRepository interface {
|
||||
Add(app App) (err error)
|
||||
Update(instanceURL string, clientID string, clientSecret string) (err error)
|
||||
Get(instanceURL string) (app App, err error)
|
||||
Get(instanceDomain string) (app App, err error)
|
||||
}
|
||||
|
||||
func (a *App) Marshal() []byte {
|
||||
str := a.InstanceURL + "\n" + a.ClientID + "\n" + a.ClientSecret
|
||||
return []byte(str)
|
||||
}
|
||||
|
||||
func (a *App) Unmarshal(instanceDomain string, data []byte) error {
|
||||
str := string(data)
|
||||
lines := strings.Split(str, "\n")
|
||||
if len(lines) != 3 {
|
||||
return errors.New("invalid data")
|
||||
}
|
||||
a.InstanceDomain = instanceDomain
|
||||
a.InstanceURL = lines[0]
|
||||
a.ClientID = lines[1]
|
||||
a.ClientSecret = lines[2]
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
package model
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string
|
||||
InstanceURL string
|
||||
AccessToken string
|
||||
ID string
|
||||
InstanceDomain string
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
type SessionRepository interface {
|
||||
|
@ -21,3 +24,26 @@ type SessionRepository interface {
|
|||
func (s Session) IsLoggedIn() bool {
|
||||
return len(s.AccessToken) > 0
|
||||
}
|
||||
|
||||
func (s *Session) Marshal() []byte {
|
||||
str := s.InstanceDomain + "\n" + s.AccessToken
|
||||
return []byte(str)
|
||||
}
|
||||
|
||||
func (s *Session) Unmarshal(id string, data []byte) error {
|
||||
str := string(data)
|
||||
lines := strings.Split(str, "\n")
|
||||
|
||||
size := len(lines)
|
||||
if size == 1 {
|
||||
s.InstanceDomain = lines[0]
|
||||
} else if size == 2 {
|
||||
s.InstanceDomain = lines[0]
|
||||
s.AccessToken = lines[1]
|
||||
} else {
|
||||
return errors.New("invalid data")
|
||||
}
|
||||
|
||||
s.ID = id
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,54 +1,33 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"web/kv"
|
||||
"web/model"
|
||||
)
|
||||
|
||||
type appRepository struct {
|
||||
db *sql.DB
|
||||
db *kv.Database
|
||||
}
|
||||
|
||||
func NewAppRepository(db *sql.DB) (*appRepository, error) {
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS app
|
||||
(instance_url varchar, client_id varchar, client_secret varchar)`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewAppRepository(db *kv.Database) *appRepository {
|
||||
return &appRepository{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *appRepository) Add(a model.App) (err error) {
|
||||
_, err = repo.db.Exec("INSERT INTO app VALUES (?, ?, ?)", a.InstanceURL, a.ClientID, a.ClientSecret)
|
||||
err = repo.db.Set(a.InstanceDomain, a.Marshal())
|
||||
return
|
||||
}
|
||||
|
||||
func (repo *appRepository) Update(instanceURL string, clientID string, clientSecret string) (err error) {
|
||||
_, err = repo.db.Exec("UPDATE app SET client_id = ?, client_secret = ? where instance_url = ?", clientID, clientSecret, instanceURL)
|
||||
return
|
||||
}
|
||||
|
||||
func (repo *appRepository) Get(instanceURL string) (a model.App, err error) {
|
||||
rows, err := repo.db.Query("SELECT * FROM app WHERE instance_url = ?", instanceURL)
|
||||
func (repo *appRepository) Get(instanceDomain string) (a model.App, err error) {
|
||||
data, err := repo.db.Get(instanceDomain)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
err = model.ErrAppNotFound
|
||||
return
|
||||
}
|
||||
|
||||
err = rows.Scan(&a.InstanceURL, &a.ClientID, &a.ClientSecret)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = a.Unmarshal(instanceDomain, data)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,54 +1,50 @@
|
|||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"web/kv"
|
||||
"web/model"
|
||||
)
|
||||
|
||||
type sessionRepository struct {
|
||||
db *sql.DB
|
||||
db *kv.Database
|
||||
}
|
||||
|
||||
func NewSessionRepository(db *sql.DB) (*sessionRepository, error) {
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS session
|
||||
(id varchar, instance_url varchar, access_token varchar)`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewSessionRepository(db *kv.Database) *sessionRepository {
|
||||
return &sessionRepository{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *sessionRepository) Add(s model.Session) (err error) {
|
||||
_, err = repo.db.Exec("INSERT INTO session VALUES (?, ?, ?)", s.ID, s.InstanceURL, s.AccessToken)
|
||||
err = repo.db.Set(s.ID, s.Marshal())
|
||||
return
|
||||
}
|
||||
|
||||
func (repo *sessionRepository) Update(sessionID string, accessToken string) (err error) {
|
||||
_, err = repo.db.Exec("UPDATE session SET access_token = ? where id = ?", accessToken, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
func (repo *sessionRepository) Get(id string) (s model.Session, err error) {
|
||||
rows, err := repo.db.Query("SELECT * FROM session WHERE id = ?", id)
|
||||
func (repo *sessionRepository) Update(id string, accessToken string) (err error) {
|
||||
data, err := repo.db.Get(id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
var s model.Session
|
||||
err = s.Unmarshal(id, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.AccessToken = accessToken
|
||||
|
||||
return repo.db.Set(id, s.Marshal())
|
||||
}
|
||||
|
||||
func (repo *sessionRepository) Get(id string) (s model.Session, err error) {
|
||||
data, err := repo.db.Get(id)
|
||||
if err != nil {
|
||||
err = model.ErrSessionNotFound
|
||||
return
|
||||
}
|
||||
|
||||
err = rows.Scan(&s.ID, &s.InstanceURL, &s.AccessToken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = s.Unmarshal(id, data)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -40,12 +40,12 @@ func (s *authService) getClient(ctx context.Context) (c *mastodon.Client, err er
|
|||
if err != nil {
|
||||
return nil, ErrInvalidSession
|
||||
}
|
||||
client, err := s.appRepo.Get(session.InstanceURL)
|
||||
client, err := s.appRepo.Get(session.InstanceDomain)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c = mastodon.NewClient(&mastodon.Config{
|
||||
Server: session.InstanceURL,
|
||||
Server: client.InstanceURL,
|
||||
ClientID: client.ClientID,
|
||||
ClientSecret: client.ClientSecret,
|
||||
AccessToken: session.AccessToken,
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"mastodon"
|
||||
|
@ -64,14 +63,18 @@ func NewService(clientName string, clientScope string, clientWebsite string,
|
|||
|
||||
func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
|
||||
redirectUrl string, sessionID string, err error) {
|
||||
if !strings.HasPrefix(instance, "https://") {
|
||||
instance = "https://" + instance
|
||||
var instanceURL string
|
||||
if strings.HasPrefix(instance, "https://") {
|
||||
instanceURL = instance
|
||||
instance = strings.TrimPrefix(instance, "https://")
|
||||
} else {
|
||||
instanceURL = "https://" + instance
|
||||
}
|
||||
|
||||
sessionID = util.NewSessionId()
|
||||
err = svc.sessionRepo.Add(model.Session{
|
||||
ID: sessionID,
|
||||
InstanceURL: instance,
|
||||
ID: sessionID,
|
||||
InstanceDomain: instance,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -85,7 +88,7 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
|
|||
|
||||
var mastoApp *mastodon.Application
|
||||
mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{
|
||||
Server: instance,
|
||||
Server: instanceURL,
|
||||
ClientName: svc.clientName,
|
||||
Scopes: svc.clientScope,
|
||||
Website: svc.clientWebsite,
|
||||
|
@ -96,9 +99,10 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
|
|||
}
|
||||
|
||||
app = model.App{
|
||||
InstanceURL: instance,
|
||||
ClientID: mastoApp.ClientID,
|
||||
ClientSecret: mastoApp.ClientSecret,
|
||||
InstanceDomain: instance,
|
||||
InstanceURL: instanceURL,
|
||||
ClientID: mastoApp.ClientID,
|
||||
ClientSecret: mastoApp.ClientSecret,
|
||||
}
|
||||
|
||||
err = svc.appRepo.Add(app)
|
||||
|
@ -136,7 +140,7 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *masto
|
|||
return
|
||||
}
|
||||
|
||||
app, err := svc.appRepo.Get(session.InstanceURL)
|
||||
app, err := svc.appRepo.Get(session.InstanceDomain)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue