mirror of
https://github.com/lejianwen/rustdesk-api.git
synced 2026-02-17 14:04:51 +08:00
first
This commit is contained in:
71
lib/cache/cache.go
vendored
Normal file
71
lib/cache/cache.go
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
Get(key string, value interface{}) error
|
||||
Set(key string, value interface{}, exp int) error
|
||||
Gc() error
|
||||
}
|
||||
|
||||
// MaxTimeOut 最大超时时间
|
||||
|
||||
const (
|
||||
TypeMem = "memory"
|
||||
TypeRedis = "redis"
|
||||
TypeFile = "file"
|
||||
MaxTimeOut = 365 * 24 * 3600
|
||||
)
|
||||
|
||||
func New(typ string) Handler {
|
||||
var cache Handler
|
||||
switch typ {
|
||||
case TypeFile:
|
||||
cache = NewFileCache()
|
||||
case TypeRedis:
|
||||
cache = new(RedisCache)
|
||||
case TypeMem: // memory
|
||||
cache = NewMemoryCache(0)
|
||||
default:
|
||||
cache = NewMemoryCache(0)
|
||||
}
|
||||
return cache
|
||||
}
|
||||
|
||||
func EncodeValue(value interface{}) (string, error) {
|
||||
/*if v, ok := value.(string); ok {
|
||||
return v, nil
|
||||
}
|
||||
if v, ok := value.([]byte); ok {
|
||||
return string(v), nil
|
||||
}*/
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func DecodeValue(value string, rtv interface{}) error {
|
||||
//判断rtv的类型是否是string,如果是string,直接赋值并返回
|
||||
/*switch rtv.(type) {
|
||||
case *string:
|
||||
*(rtv.(*string)) = value
|
||||
return nil
|
||||
case *[]byte:
|
||||
*(rtv.(*[]byte)) = []byte(value)
|
||||
return nil
|
||||
//struct
|
||||
case *interface{}:
|
||||
err := json.Unmarshal(([]byte)(value), rtv)
|
||||
return err
|
||||
default:
|
||||
err := json.Unmarshal(([]byte)(value), rtv)
|
||||
return err
|
||||
}
|
||||
*/
|
||||
err := json.Unmarshal(([]byte)(value), rtv)
|
||||
return err
|
||||
}
|
||||
92
lib/cache/cache_test.go
vendored
Normal file
92
lib/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSimpleCache(t *testing.T) {
|
||||
|
||||
type st struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
|
||||
items := map[string]interface{}{}
|
||||
items["a"] = "b"
|
||||
items["b"] = "c"
|
||||
|
||||
ab := &st{
|
||||
A: "a",
|
||||
B: "b",
|
||||
}
|
||||
items["ab"] = *ab
|
||||
|
||||
a := items["a"]
|
||||
fmt.Println(a)
|
||||
|
||||
b := items["b"]
|
||||
fmt.Println(b)
|
||||
|
||||
ab.A = "aa"
|
||||
ab2 := st{}
|
||||
ab2 = (items["ab"]).(st)
|
||||
fmt.Println(ab2, reflect.TypeOf(ab2))
|
||||
|
||||
}
|
||||
|
||||
func TestFileCacheSet(t *testing.T) {
|
||||
fc := New("file")
|
||||
err := fc.Set("123", "ddd", 0)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCacheGet(t *testing.T) {
|
||||
fc := New("file")
|
||||
err := fc.Set("123", "45156", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
res := ""
|
||||
err = fc.Get("123", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
fmt.Println("res", res)
|
||||
}
|
||||
|
||||
func TestRedisCacheSet(t *testing.T) {
|
||||
rc := NewRedis(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
err := rc.Set("123", "ddd", 0)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisCacheGet(t *testing.T) {
|
||||
rc := NewRedis(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
err := rc.Set("123", "451156", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
res := ""
|
||||
err = rc.Get("123", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
fmt.Println("res", res)
|
||||
}
|
||||
103
lib/cache/file.go
vendored
Normal file
103
lib/cache/file.go
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FileCache struct {
|
||||
mu sync.Mutex
|
||||
locks map[string]*sync.Mutex
|
||||
Dir string
|
||||
}
|
||||
|
||||
func (fc *FileCache) getLock(key string) *sync.Mutex {
|
||||
fc.mu.Lock()
|
||||
defer fc.mu.Unlock()
|
||||
if fc.locks == nil {
|
||||
fc.locks = make(map[string]*sync.Mutex)
|
||||
}
|
||||
if _, ok := fc.locks[key]; !ok {
|
||||
fc.locks[key] = new(sync.Mutex)
|
||||
}
|
||||
return fc.locks[key]
|
||||
}
|
||||
|
||||
func (c *FileCache) Get(key string, value interface{}) error {
|
||||
data, _ := c.getValue(key)
|
||||
err := DecodeValue(data, value)
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取值,如果文件不存在或者过期,返回空,过滤掉错误
|
||||
func (c *FileCache) getValue(key string) (string, error) {
|
||||
f := c.fileName(key)
|
||||
fileInfo, err := os.Stat(f)
|
||||
if err != nil {
|
||||
//文件不存在
|
||||
return "", nil
|
||||
}
|
||||
difT := time.Now().Sub(fileInfo.ModTime())
|
||||
if difT >= 0 {
|
||||
os.Remove(f)
|
||||
return "", nil
|
||||
}
|
||||
data, err := os.ReadFile(f)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// 保存值
|
||||
func (c *FileCache) saveValue(key string, value string, exp int) error {
|
||||
f := c.fileName(key)
|
||||
lock := c.getLock(f)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err := os.WriteFile(f, ([]byte)(value), 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exp <= 0 {
|
||||
exp = MaxTimeOut
|
||||
}
|
||||
expFromNow := time.Now().Add(time.Duration(exp) * time.Second)
|
||||
err = os.Chtimes(f, expFromNow, expFromNow)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *FileCache) Set(key string, value interface{}, exp int) error {
|
||||
str, err := EncodeValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.saveValue(key, str, exp)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *FileCache) SetDir(path string) {
|
||||
c.Dir = path
|
||||
}
|
||||
|
||||
func (c *FileCache) fileName(key string) string {
|
||||
f := c.Dir + string(os.PathSeparator) + fmt.Sprintf("%x", md5.Sum([]byte(key)))
|
||||
return f
|
||||
}
|
||||
|
||||
func (c *FileCache) Gc() error {
|
||||
//检查文件过期时间,并删除
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFileCache() *FileCache {
|
||||
return &FileCache{
|
||||
locks: make(map[string]*sync.Mutex),
|
||||
Dir: os.TempDir(),
|
||||
}
|
||||
}
|
||||
94
lib/cache/file_test.go
vendored
Normal file
94
lib/cache/file_test.go
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFileSet(t *testing.T) {
|
||||
fc := NewFileCache()
|
||||
err := fc.Set("123", "ddd", 0)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileGet(t *testing.T) {
|
||||
fc := NewFileCache()
|
||||
res := ""
|
||||
err := fc.Get("123", &res)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
fmt.Println("res", res)
|
||||
}
|
||||
func TestFileSetGet(t *testing.T) {
|
||||
fc := NewFileCache()
|
||||
err := fc.Set("key1", "ddd", 0)
|
||||
res := ""
|
||||
err = fc.Get("key1", &res)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
fmt.Println("res", res)
|
||||
}
|
||||
func TestFileGetJson(t *testing.T) {
|
||||
fc := NewFileCache()
|
||||
old := &r{
|
||||
A: "a", B: "b",
|
||||
}
|
||||
fc.Set("123", old, 0)
|
||||
res := &r{}
|
||||
err2 := fc.Get("123", res)
|
||||
fmt.Println("res", res)
|
||||
if err2 != nil {
|
||||
t.Fatalf("读取失败" + err2.Error())
|
||||
}
|
||||
}
|
||||
func TestFileSetGetJson(t *testing.T) {
|
||||
fc := NewFileCache()
|
||||
|
||||
old_rr := &rr{AA: "aa", BB: "bb"}
|
||||
old := &r{
|
||||
A: "a", B: "b",
|
||||
R: old_rr,
|
||||
}
|
||||
err := fc.Set("123", old, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
//old_rr.AA = "aaa"
|
||||
fmt.Println("old_rr", old)
|
||||
|
||||
res := &r{}
|
||||
err2 := fc.Get("123", res)
|
||||
fmt.Println("res", res)
|
||||
if err2 != nil {
|
||||
t.Fatalf("读取失败" + err2.Error())
|
||||
}
|
||||
if !reflect.DeepEqual(res, old) {
|
||||
t.Fatalf("读取错误")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkSet(b *testing.B) {
|
||||
fc := NewFileCache()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
fc.Set("123", "{dsv}", 1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGet(b *testing.B) {
|
||||
fc := NewFileCache()
|
||||
b.ResetTimer()
|
||||
v := ""
|
||||
for i := 0; i < b.N; i++ {
|
||||
fc.Get("123", &v)
|
||||
}
|
||||
}
|
||||
215
lib/cache/memory.go
vendored
Normal file
215
lib/cache/memory.go
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"container/list"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MemoryCache struct {
|
||||
data map[string]*CacheItem
|
||||
ll *list.List // 用于实现LRU
|
||||
pq PriorityQueue // 用于实现TTL
|
||||
quit chan struct{}
|
||||
mu sync.Mutex
|
||||
maxBytes int64
|
||||
usedBytes int64
|
||||
}
|
||||
|
||||
type CacheItem struct {
|
||||
Key string
|
||||
Value string
|
||||
Expiration int64
|
||||
Index int
|
||||
ListEle *list.Element
|
||||
}
|
||||
|
||||
type PriorityQueue []*CacheItem
|
||||
|
||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||
|
||||
func (pq PriorityQueue) Less(i, j int) bool {
|
||||
return pq[i].Expiration < pq[j].Expiration
|
||||
}
|
||||
|
||||
func (pq PriorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].Index = i
|
||||
pq[j].Index = j
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Push(x interface{}) {
|
||||
item := x.(*CacheItem)
|
||||
item.Index = len(*pq)
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *PriorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
old[n-1] = nil // avoid memory leak
|
||||
item.Index = -1 // for safety
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (m *MemoryCache) Get(key string, value interface{}) error {
|
||||
// 使用反射将存储的值设置到传入的指针变量中
|
||||
val := reflect.ValueOf(value)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return errors.New("value must be a pointer")
|
||||
}
|
||||
//设为空值
|
||||
val.Elem().Set(reflect.Zero(val.Elem().Type()))
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if item, ok := m.data[key]; ok {
|
||||
if item.Expiration < time.Now().UnixNano() {
|
||||
m.deleteItem(item)
|
||||
return nil
|
||||
}
|
||||
//移动到队列尾部
|
||||
m.ll.MoveToBack(item.ListEle)
|
||||
|
||||
err := DecodeValue(item.Value, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryCache) Set(key string, value interface{}, exp int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
v, err := EncodeValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//key 所占用的内存
|
||||
keyBytes := int64(len(key))
|
||||
//value所占用的内存空间大小
|
||||
valueBytes := int64(len(v))
|
||||
//判断是否超过最大内存限制
|
||||
if m.maxBytes != 0 && m.maxBytes < keyBytes+valueBytes {
|
||||
return errors.New("exceed maxBytes")
|
||||
}
|
||||
m.usedBytes += keyBytes + valueBytes
|
||||
if m.maxBytes != 0 && m.usedBytes > m.maxBytes {
|
||||
m.RemoveOldest()
|
||||
}
|
||||
if exp <= 0 {
|
||||
exp = MaxTimeOut
|
||||
}
|
||||
expiration := time.Now().Add(time.Duration(exp) * time.Second).UnixNano()
|
||||
item, exists := m.data[key]
|
||||
if exists {
|
||||
item.Value = v
|
||||
item.Expiration = expiration
|
||||
heap.Fix(&m.pq, item.Index)
|
||||
m.ll.MoveToBack(item.ListEle)
|
||||
} else {
|
||||
ele := m.ll.PushBack(key)
|
||||
item = &CacheItem{
|
||||
Key: key,
|
||||
Value: v,
|
||||
Expiration: expiration,
|
||||
ListEle: ele,
|
||||
}
|
||||
m.data[key] = item
|
||||
heap.Push(&m.pq, item)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryCache) RemoveOldest() {
|
||||
for m.maxBytes != 0 && m.usedBytes > m.maxBytes {
|
||||
elem := m.ll.Front()
|
||||
if elem != nil {
|
||||
key := elem.Value.(string)
|
||||
item := m.data[key]
|
||||
m.deleteItem(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictExpiredItems removes all expired items from the cache.
|
||||
func (m *MemoryCache) evictExpiredItems() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
now := time.Now().UnixNano()
|
||||
for m.pq.Len() > 0 {
|
||||
item := m.pq[0]
|
||||
if item.Expiration > now {
|
||||
break
|
||||
}
|
||||
m.deleteItem(item)
|
||||
}
|
||||
}
|
||||
|
||||
// startEviction starts a goroutine that evicts expired items from the cache.
|
||||
func (m *MemoryCache) startEviction() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.evictExpiredItems()
|
||||
case <-m.quit:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// stopEviction 停止定时清理
|
||||
func (m *MemoryCache) stopEviction() {
|
||||
close(m.quit)
|
||||
}
|
||||
|
||||
// deleteItem removes a key from the cache.
|
||||
func (m *MemoryCache) deleteItem(item *CacheItem) {
|
||||
m.ll.Remove(item.ListEle)
|
||||
m.usedBytes -= int64(len(item.Key)) + int64(len(item.Value))
|
||||
heap.Remove(&m.pq, item.Index)
|
||||
delete(m.data, item.Key)
|
||||
}
|
||||
|
||||
func (m *MemoryCache) Gc() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data = make(map[string]*CacheItem)
|
||||
m.ll = list.New()
|
||||
m.pq = make(PriorityQueue, 0)
|
||||
heap.Init(&m.pq)
|
||||
m.usedBytes = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewMemoryCache creates a new MemoryCache.default maxBytes is 0, means no limit.
|
||||
func NewMemoryCache(maxBytes int64) *MemoryCache {
|
||||
cache := &MemoryCache{
|
||||
data: make(map[string]*CacheItem),
|
||||
pq: make(PriorityQueue, 0),
|
||||
quit: make(chan struct{}),
|
||||
ll: list.New(),
|
||||
maxBytes: maxBytes,
|
||||
}
|
||||
heap.Init(&cache.pq)
|
||||
cache.startEviction()
|
||||
return cache
|
||||
}
|
||||
107
lib/cache/memory_test.go
vendored
Normal file
107
lib/cache/memory_test.go
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMemorySet(t *testing.T) {
|
||||
mc := NewMemoryCache(0)
|
||||
err := mc.Set("123", "44567", 0)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryGet(t *testing.T) {
|
||||
mc := NewMemoryCache(0)
|
||||
mc.Set("123", "44567", 0)
|
||||
res := ""
|
||||
err := mc.Get("123", &res)
|
||||
fmt.Println("res", res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败 " + err.Error())
|
||||
}
|
||||
if res != "44567" {
|
||||
t.Fatalf("读取错误")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMemorySetExpGet(t *testing.T) {
|
||||
mc := NewMemoryCache(0)
|
||||
//mc.stopEviction()
|
||||
mc.Set("1", "10", 10)
|
||||
mc.Set("2", "5", 5)
|
||||
err := mc.Set("3", "3", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
|
||||
res := ""
|
||||
err = mc.Get("3", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res 3", res)
|
||||
time.Sleep(4 * time.Second)
|
||||
//res = ""
|
||||
err = mc.Get("3", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res 3", res)
|
||||
err = mc.Get("2", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res 2", res)
|
||||
err = mc.Get("1", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res 1", res)
|
||||
|
||||
}
|
||||
func TestMemoryLru(t *testing.T) {
|
||||
mc := NewMemoryCache(18)
|
||||
mc.Set("1", "1111", 10)
|
||||
mc.Set("2", "2222", 5)
|
||||
//读取一次,2就会被放到最后
|
||||
mc.Get("1", nil)
|
||||
err := mc.Set("3", "三", 3)
|
||||
if err != nil {
|
||||
//t.Fatalf("写入失败")
|
||||
}
|
||||
|
||||
res := ""
|
||||
err = mc.Get("3", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res3", res)
|
||||
res = ""
|
||||
err = mc.Get("2", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res2", res)
|
||||
res = ""
|
||||
err = mc.Get("1", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败" + err.Error())
|
||||
}
|
||||
fmt.Println("res1", res)
|
||||
|
||||
}
|
||||
func BenchmarkMemorySet(b *testing.B) {
|
||||
mc := NewMemoryCache(0)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
value := fmt.Sprintf("value%d", i)
|
||||
mc.Set(key, value, 1000)
|
||||
}
|
||||
}
|
||||
49
lib/cache/redis.go
vendored
Normal file
49
lib/cache/redis.go
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
type RedisCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func RedisCacheInit(conf *redis.Options) *RedisCache {
|
||||
c := &RedisCache{}
|
||||
c.rdb = redis.NewClient(conf)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *RedisCache) Get(key string, value interface{}) error {
|
||||
data, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err1 := DecodeValue(data, value)
|
||||
return err1
|
||||
}
|
||||
|
||||
func (c *RedisCache) Set(key string, value interface{}, exp int) error {
|
||||
str, err := EncodeValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exp <= 0 {
|
||||
exp = MaxTimeOut
|
||||
}
|
||||
_, err1 := c.rdb.Set(ctx, key, str, time.Duration(exp)*time.Second).Result()
|
||||
return err1
|
||||
}
|
||||
|
||||
func (c *RedisCache) Gc() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRedis(conf *redis.Options) *RedisCache {
|
||||
cache := RedisCacheInit(conf)
|
||||
return cache
|
||||
}
|
||||
94
lib/cache/redis_test.go
vendored
Normal file
94
lib/cache/redis_test.go
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedisSet(t *testing.T) {
|
||||
//rc := New("redis")
|
||||
rc := RedisCacheInit(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
err := rc.Set("123", "ddd", 0)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisGet(t *testing.T) {
|
||||
rc := RedisCacheInit(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
err := rc.Set("123", "451156", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
res := ""
|
||||
err = rc.Get("123", &res)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
fmt.Println("res", res)
|
||||
}
|
||||
|
||||
func TestRedisGetJson(t *testing.T) {
|
||||
rc := RedisCacheInit(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
type r struct {
|
||||
Aa string `json:"a"`
|
||||
B string `json:"c"`
|
||||
}
|
||||
old := &r{
|
||||
Aa: "ab", B: "cdc",
|
||||
}
|
||||
err := rc.Set("1233", old, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
|
||||
res := &r{}
|
||||
err2 := rc.Get("1233", res)
|
||||
if err2 != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
if !reflect.DeepEqual(res, old) {
|
||||
t.Fatalf("读取错误")
|
||||
}
|
||||
fmt.Println(res, res.Aa)
|
||||
}
|
||||
|
||||
func BenchmarkRSet(b *testing.B) {
|
||||
rc := RedisCacheInit(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rc.Set("123", "{dsv}", 1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRGet(b *testing.B) {
|
||||
rc := RedisCacheInit(&redis.Options{
|
||||
Addr: "192.168.1.168:6379",
|
||||
Password: "", // no password set
|
||||
DB: 0, // use default DB
|
||||
})
|
||||
b.ResetTimer()
|
||||
v := ""
|
||||
for i := 0; i < b.N; i++ {
|
||||
rc.Get("123", &v)
|
||||
}
|
||||
}
|
||||
65
lib/cache/simple_cache.go
vendored
Normal file
65
lib/cache/simple_cache.go
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// 此处实现了一个简单的缓存,用于测试
|
||||
// SimpleCache is a simple cache implementation
|
||||
type SimpleCache struct {
|
||||
data map[string]interface{}
|
||||
mu sync.Mutex
|
||||
maxBytes int64
|
||||
usedBytes int64
|
||||
}
|
||||
|
||||
func (s *SimpleCache) Get(key string, value interface{}) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 使用反射将存储的值设置到传入的指针变量中
|
||||
val := reflect.ValueOf(value)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return errors.New("value must be a pointer")
|
||||
}
|
||||
v, ok := s.data[key]
|
||||
if !ok {
|
||||
//设为空值
|
||||
val.Elem().Set(reflect.Zero(val.Elem().Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
vval := reflect.ValueOf(v)
|
||||
if val.Elem().Type() != vval.Type() {
|
||||
//设为空值
|
||||
val.Elem().Set(reflect.Zero(val.Elem().Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
val.Elem().Set(reflect.ValueOf(v))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SimpleCache) Set(key string, value interface{}, exp int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// 检查传入的值是否是指针,如果是则取其值
|
||||
val := reflect.ValueOf(value)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
s.data[key] = val.Interface()
|
||||
return nil
|
||||
}
|
||||
func (s *SimpleCache) Gc() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSimpleCache() *SimpleCache {
|
||||
return &SimpleCache{
|
||||
data: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
108
lib/cache/simple_cache_test.go
vendored
Normal file
108
lib/cache/simple_cache_test.go
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSimpleCache_Set(t *testing.T) {
|
||||
s := NewSimpleCache()
|
||||
err := s.Set("key", "value", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
err = s.Set("key", 111, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleCache_Get(t *testing.T) {
|
||||
s := NewSimpleCache()
|
||||
err := s.Set("key", "value", 0)
|
||||
value := ""
|
||||
err = s.Get("key", &value)
|
||||
fmt.Println("value", value)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
|
||||
err = s.Set("key1", 11, 0)
|
||||
value1 := 0
|
||||
err = s.Get("key1", &value1)
|
||||
fmt.Println("value1", value1)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
|
||||
err = s.Set("key2", []byte{'a', 'b'}, 0)
|
||||
value2 := []byte{}
|
||||
err = s.Get("key2", &value2)
|
||||
fmt.Println("value2", string(value2))
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
|
||||
err = s.Set("key3", 33.33, 0)
|
||||
var value3 int
|
||||
err = s.Get("key3", &value3)
|
||||
fmt.Println("value3", value3)
|
||||
if err != nil {
|
||||
t.Fatalf("读取失败")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type r struct {
|
||||
A string `json:"a"`
|
||||
B string `json:"b"`
|
||||
R *rr `json:"r"`
|
||||
}
|
||||
type r2 struct {
|
||||
A string `json:"a"`
|
||||
B string `json:"b"`
|
||||
}
|
||||
type rr struct {
|
||||
AA string `json:"aa"`
|
||||
BB string `json:"bb"`
|
||||
}
|
||||
|
||||
func TestSimpleCache_GetStruct(t *testing.T) {
|
||||
s := NewSimpleCache()
|
||||
|
||||
old_rr := &rr{
|
||||
AA: "aa", BB: "bb",
|
||||
}
|
||||
|
||||
old := &r{
|
||||
A: "ab", B: "cdc",
|
||||
R: old_rr,
|
||||
}
|
||||
err := s.Set("key", old, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("写入失败")
|
||||
}
|
||||
|
||||
res := &r{}
|
||||
err2 := s.Get("key", res)
|
||||
fmt.Println("res", res)
|
||||
if err2 != nil {
|
||||
t.Fatalf("读取失败" + err2.Error())
|
||||
|
||||
}
|
||||
|
||||
//修改原始值,看后面是否会变化
|
||||
old.A = "aa"
|
||||
old_rr.AA = "aaa"
|
||||
fmt.Println("old", old)
|
||||
res2 := &r{}
|
||||
err3 := s.Get("key", res2)
|
||||
fmt.Println("res2", res2, res2.R.AA, res2.R.BB)
|
||||
if err3 != nil {
|
||||
t.Fatalf("读取失败" + err3.Error())
|
||||
|
||||
}
|
||||
//if reflect.DeepEqual(res, old) {
|
||||
// t.Fatalf("读取错误")
|
||||
//}
|
||||
}
|
||||
61
lib/jwt/jwt.go
Normal file
61
lib/jwt/jwt.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Jwt struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
TokenExpireDuration time.Duration
|
||||
}
|
||||
|
||||
type UserClaims struct {
|
||||
UserId uint `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func NewJwt(privateKeyFile string, tokenExpireDuration time.Duration) *Jwt {
|
||||
privateKeyContent, err := os.ReadFile(privateKeyFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyContent)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &Jwt{
|
||||
privateKey: privateKey,
|
||||
TokenExpireDuration: tokenExpireDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Jwt) GenerateToken(userId uint) string {
|
||||
t := jwt.NewWithClaims(jwt.SigningMethodRS256,
|
||||
UserClaims{
|
||||
UserId: userId,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.TokenExpireDuration)),
|
||||
},
|
||||
})
|
||||
token, err := t.SignedString(s.privateKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *Jwt) ParseToken(tokenString string) (uint, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.privateKey.Public(), nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
return claims.UserId, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
80
lib/jwt/jwt_test.go
Normal file
80
lib/jwt/jwt_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var pk = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAnJpq2Sy91iGW3+EuG4V2ke59tITpGINzht0rO8WiRwu11W4p
|
||||
wakS4K4BbjvmC8YjaxXhKE5LHDw0IXvTdIDN7Fuu4qs9xWXIoK+nC3qWrVBtj/1o
|
||||
RJrYme1NenTXEgPlN1FOU6/9XQGgvb+1MSNqxknYo7183mHACvsIIuSTMEFhUbUw
|
||||
XYVQrCtACUILZ9wIDOEzclIY2ZPMTnL1vkvfj629KwGtAvpEyc96Y/HMSH5/VkiG
|
||||
p6L+k+NSjco9HntAGYTiQkfranvdqxRDUsKS53SbV3QSz1zc0l5OEyZDuxFTL7UC
|
||||
7v0G/HVqz6mLpMje756PG/WEpwa/lADc/8FJ5QIDAQABAoIBAEsqUt6qevOsa55J
|
||||
lrfe92pT7kIXCUqazXiN75Jg6eLv2/b1SVWKsWTmIAmo9mHwWE+t0MRnz+VdgCgS
|
||||
JwxkRnKMDwT87Eky8Xku1h7MWEYXtH7IQqOrLwuyut1r907OT9adT9sbPaDGh0CM
|
||||
I4vSVA2YpELzUFvszyB2HRGiZINkHfdLsNxUKsHJOdXbv82RItwzmCYcZismnR3J
|
||||
P8THn06eoBNtlqwdFziuREOzjNnj6J/3glhR5mu4c4+AJoj0hmVaBDfac3GsQsbP
|
||||
x79QQPrUqH9UZ4szubYHXP0uRi/ARlHQ+GNp6foYIsevC0OtLdau0/ouFlfGkEep
|
||||
3aIV5oECgYEAyyWrNhw+BhNFXsyPzEQ4/mO5ucup3cE/tAAtLiSckoXjmY8K7PQr
|
||||
xfKRCkuM1qpcxtYkbTs35aOdK48gL0NVd50QzrWFrQkQkVnpnJ1lYeVgEL1DmalD
|
||||
B55bwTdShcs0gEoKefZCvmotrmYdSpMGsapqqbZFrysFFzRDyDxnHfcCgYEAxVjA
|
||||
/dXxCEUjYFVC3i833lI/yiycJrhjIeffc6DqpSReuTU+i8Nh3sLiytaSqPFVASDS
|
||||
08K3JwVguMTzDgrYkl365lm50WxcBuNgLkSqA90vE/H6gkRZVkuzOb7T+ZdDxf0s
|
||||
7RH4aqeeOSiOcZ3uC+d53UArJFidETXbgguXkAMCgYA22Ynbx05b15IwYW0mCvmU
|
||||
fhqkdr/7lvT7RdztC4eW7D2itYOOrPKwtKjCrdluEHuSWDlnoMib4UxLeY6IFFcc
|
||||
P7VNCqf4K21kwXEZD0pTX1pLyr5Y2+G0SeaeSbCnXVFknhksCvjEbui8oOehvgbd
|
||||
q5S3E/bGsAfk1wDCLMTuywKBgACHrH0CBhOvm9i2YeeW2N+P+PviAslX1WxR4xe8
|
||||
ZuTqpBZ7Ph/B9pFSlKlWyi4J9+B45hgLfdJtAUV9welXvh0mg3X657TYRab/FVMK
|
||||
fCpmfangDHwtEtBYg7K0AH27GkN92pEIa1JeAN7GbRuBARKnHHyrn3IJiuJw8pX2
|
||||
0gFhAoGBAIquI9sAB2dKEOMW+iQJkLH8Hh8/EWyslow+QJiyIsRe1l9jtkOxC5D3
|
||||
Hj4yO4j5LOWDMTgDcLsZTxbGiTzkNc/HghrNIevDAQdgjJQNl84zDjyyCA4r/MA7
|
||||
bYJTtYj8q6J0EDbRdT9b6hMclyzjNXdx2loJxR0R8WUeL1lDEPq8
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
// 测试token生成
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
jwtService := NewJwt(pk, time.Second*1000)
|
||||
token := jwtService.GenerateToken(1)
|
||||
if token == "" {
|
||||
t.Fatal("token生成失败")
|
||||
}
|
||||
fmt.Println(pk, token)
|
||||
}
|
||||
|
||||
// 测试token解析
|
||||
func TestParseToken(t *testing.T) {
|
||||
jwtService := NewJwt(pk, time.Second*1000)
|
||||
token := jwtService.GenerateToken(999)
|
||||
if token == "" {
|
||||
t.Fatal("token生成失败")
|
||||
}
|
||||
uid, err := jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
|
||||
t.Fatal("token解析失败", err)
|
||||
}
|
||||
if uid != 999 {
|
||||
t.Fatal("token解析失败")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJwtService_GenerateToken(b *testing.B) {
|
||||
jwtService := NewJwt(pk, time.Second*1000)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
jwtService.GenerateToken(999)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJwtService_ParseToken(b *testing.B) {
|
||||
jwtService := NewJwt(pk, time.Second*1000)
|
||||
token := jwtService.GenerateToken(999)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = jwtService.ParseToken(token)
|
||||
}
|
||||
|
||||
}
|
||||
32
lib/lock/local.go
Normal file
32
lib/lock/local.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Local struct {
|
||||
Locks *sync.Map
|
||||
}
|
||||
|
||||
func (l *Local) Lock(key string) {
|
||||
lock := l.GetLock(key)
|
||||
lock.Lock()
|
||||
}
|
||||
|
||||
func (l *Local) UnLock(key string) {
|
||||
lock, ok := l.Locks.Load(key)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Local) GetLock(key string) *sync.Mutex {
|
||||
lock, _ := l.Locks.LoadOrStore(key, &sync.Mutex{})
|
||||
return lock.(*sync.Mutex)
|
||||
}
|
||||
|
||||
func NewLocal() *Local {
|
||||
return &Local{
|
||||
Locks: &sync.Map{},
|
||||
}
|
||||
}
|
||||
100
lib/lock/local_test.go
Normal file
100
lib/lock/local_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLocal_GetLock(t *testing.T) {
|
||||
l := NewLocal()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
var l1 *sync.Mutex
|
||||
var l2 *sync.Mutex
|
||||
var l3 *sync.Mutex
|
||||
i := 0
|
||||
go func() {
|
||||
l1 = l.GetLock("key")
|
||||
fmt.Println("l1", l1, i)
|
||||
l1.Lock()
|
||||
fmt.Println("l1", i)
|
||||
i++
|
||||
l1.Unlock()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
l2 = l.GetLock("key")
|
||||
fmt.Println("l2", l2, i)
|
||||
l2.Lock()
|
||||
fmt.Println("l2", i)
|
||||
i++
|
||||
l2.Unlock()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
l3 = l.GetLock("key")
|
||||
fmt.Println("l3", l3, i)
|
||||
l3.Lock()
|
||||
fmt.Println("l3", i)
|
||||
i++
|
||||
l3.Unlock()
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
fmt.Println(l1, l2, l3)
|
||||
fmt.Println(l1 == l2, l2 == l3)
|
||||
fmt.Println(&sync.Mutex{} == &sync.Mutex{})
|
||||
}
|
||||
|
||||
func TestLocal_Lock(t *testing.T) {
|
||||
l := NewLocal()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
i := 0
|
||||
go func() {
|
||||
l.Lock("key")
|
||||
fmt.Println("l1", i)
|
||||
i++
|
||||
l.UnLock("key")
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
l.Lock("key")
|
||||
fmt.Println("l2", i)
|
||||
i++
|
||||
l.UnLock("key")
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
l.Lock("key")
|
||||
fmt.Println("l3", i)
|
||||
i++
|
||||
l.UnLock("key")
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
func TestSyncMap(t *testing.T) {
|
||||
m := sync.Map{}
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
v, ok := m.LoadOrStore("key", 1)
|
||||
fmt.Println(1, v, ok)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
v, ok := m.LoadOrStore("key", 2)
|
||||
fmt.Println(2, v, ok)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
v, ok := m.LoadOrStore("key", 3)
|
||||
fmt.Println(3, v, ok)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
9
lib/lock/lock.go
Normal file
9
lib/lock/lock.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package lock
|
||||
|
||||
import "sync"
|
||||
|
||||
type Locker interface {
|
||||
GetLock(key string) *sync.Mutex
|
||||
Lock(key string)
|
||||
UnLock(key string)
|
||||
}
|
||||
54
lib/logger/logger.go
Normal file
54
lib/logger/logger.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
nested "github.com/antonfisher/nested-logrus-formatter"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
DebugMode = "debug"
|
||||
ReleaseMode = "release"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Path string
|
||||
Level string
|
||||
ReportCaller bool
|
||||
}
|
||||
|
||||
func New(c *Config) *log.Logger {
|
||||
log.SetFormatter(&nested.Formatter{
|
||||
// HideKeys: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
NoColors: true,
|
||||
NoFieldsColors: true,
|
||||
//FieldsOrder: []string{"name", "age"},
|
||||
})
|
||||
|
||||
// 日志文件
|
||||
f := c.Path
|
||||
var write io.Writer
|
||||
if f != "" {
|
||||
fwriter, err := os.OpenFile(f, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
panic("open log file fail!")
|
||||
}
|
||||
write = io.MultiWriter(fwriter, os.Stdout)
|
||||
} else {
|
||||
write = os.Stdout
|
||||
}
|
||||
|
||||
log.SetOutput(write)
|
||||
|
||||
log.SetReportCaller(c.ReportCaller)
|
||||
|
||||
level, err2 := log.ParseLevel(c.Level)
|
||||
if err2 != nil {
|
||||
level = log.DebugLevel
|
||||
}
|
||||
log.SetLevel(level)
|
||||
|
||||
return log.StandardLogger()
|
||||
}
|
||||
40
lib/orm/mysql.go
Normal file
40
lib/orm/mysql.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MysqlConfig struct {
|
||||
Dns string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
}
|
||||
|
||||
func NewMysql(mysqlConf *MysqlConfig) *gorm.DB {
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DSN: mysqlConf.Dns, // DSN data source name
|
||||
DefaultStringSize: 256, // string 类型字段的默认长度
|
||||
//DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持
|
||||
//DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
|
||||
//DontSupportRenameColumn: true, // 用 `change` 重命名列,MySQL 8 之前的数据库和 MariaDB 不支持重命名列
|
||||
//SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置
|
||||
}), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
sqlDB, err2 := db.DB()
|
||||
if err2 != nil {
|
||||
fmt.Println(err2)
|
||||
}
|
||||
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
|
||||
sqlDB.SetMaxIdleConns(mysqlConf.MaxIdleConns)
|
||||
|
||||
// SetMaxOpenConns 设置打开数据库连接的最大数量。
|
||||
sqlDB.SetMaxOpenConns(mysqlConf.MaxOpenConns)
|
||||
|
||||
return db
|
||||
}
|
||||
30
lib/orm/sqlite.go
Normal file
30
lib/orm/sqlite.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SqliteConfig struct {
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
}
|
||||
|
||||
func NewSqlite(sqliteConf *SqliteConfig) *gorm.DB {
|
||||
db, err := gorm.Open(sqlite.Open("./data/rustdeskapi.db"), &gorm.Config{})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
sqlDB, err2 := db.DB()
|
||||
if err2 != nil {
|
||||
fmt.Println(err2)
|
||||
}
|
||||
// SetMaxIdleConns 设置空闲连接池中连接的最大数量
|
||||
sqlDB.SetMaxIdleConns(sqliteConf.MaxIdleConns)
|
||||
|
||||
// SetMaxOpenConns 设置打开数据库连接的最大数量。
|
||||
sqlDB.SetMaxOpenConns(sqliteConf.MaxOpenConns)
|
||||
|
||||
return db
|
||||
}
|
||||
4
lib/upload/local.go
Normal file
4
lib/upload/local.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package upload
|
||||
|
||||
type Local struct {
|
||||
}
|
||||
475
lib/upload/oss.go
Normal file
475
lib/upload/oss.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Oss struct {
|
||||
AccessKeyId string
|
||||
AccessKeySecret string
|
||||
Host string
|
||||
CallbackUrl string
|
||||
ExpireTime int64
|
||||
MaxByte int64
|
||||
}
|
||||
|
||||
const (
|
||||
base64Table = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-_"
|
||||
)
|
||||
|
||||
var coder = base64.NewEncoding(base64Table)
|
||||
|
||||
func base64Encode(src []byte) []byte {
|
||||
return []byte(coder.EncodeToString(src))
|
||||
}
|
||||
|
||||
func get_gmt_iso8601(expire_end int64) string {
|
||||
var tokenExpire = time.Unix(expire_end, 0).UTC().Format("2006-01-02T15:04:05Z")
|
||||
return tokenExpire
|
||||
}
|
||||
|
||||
type ConfigStruct struct {
|
||||
Expiration string `json:"expiration"`
|
||||
Conditions [][]interface{} `json:"conditions"`
|
||||
}
|
||||
|
||||
type PolicyToken struct {
|
||||
AccessKeyId string `json:"accessid"`
|
||||
Host string `json:"host"`
|
||||
Expire int64 `json:"expire"`
|
||||
Signature string `json:"signature"`
|
||||
Policy string `json:"policy"`
|
||||
Directory string `json:"dir"`
|
||||
Callback string `json:"callback"`
|
||||
}
|
||||
|
||||
type CallbackParam struct {
|
||||
CallbackUrl string `json:"callbackUrl"`
|
||||
CallbackBody string `json:"callbackBody"`
|
||||
CallbackBodyType string `json:"callbackBodyType"`
|
||||
}
|
||||
|
||||
type CallbackBaseForm struct {
|
||||
Bucket string `json:"bucket" form:"bucket"`
|
||||
Etag string `json:"etag" form:"etag"`
|
||||
Filename string `json:"filename" form:"filename"`
|
||||
Size string `json:"size" form:"size"`
|
||||
MimeType string `json:"mime_type" form:"mime_type"`
|
||||
Height string `json:"height" form:"height"`
|
||||
Width string `json:"width" form:"width"`
|
||||
Format string `json:"format" form:"format"`
|
||||
OriginFilename string `json:"origin_filename" form:"origin_filename"`
|
||||
}
|
||||
|
||||
func (oc *Oss) GetPolicyToken(uploadDir string) string {
|
||||
now := time.Now().Unix()
|
||||
expire_end := now + oc.ExpireTime
|
||||
var tokenExpire = get_gmt_iso8601(expire_end)
|
||||
|
||||
//create post policy json
|
||||
var config ConfigStruct
|
||||
config.Expiration = tokenExpire
|
||||
var condition = []interface{}{"starts-with", "$key", uploadDir}
|
||||
var condition_limit = []interface{}{"content-length-range", 0, oc.MaxByte}
|
||||
config.Conditions = append(config.Conditions, condition, condition_limit)
|
||||
|
||||
//calucate signature
|
||||
result, err := json.Marshal(config)
|
||||
debyte := base64.StdEncoding.EncodeToString(result)
|
||||
h := hmac.New(func() hash.Hash {
|
||||
return sha1.New()
|
||||
}, []byte(oc.AccessKeySecret))
|
||||
io.WriteString(h, debyte)
|
||||
signedStr := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
var callbackParam CallbackParam
|
||||
callbackParam.CallbackUrl = oc.CallbackUrl
|
||||
|
||||
callbackParam.CallbackBody =
|
||||
"bucket=${bucket}&" +
|
||||
"etag=${etag}&" +
|
||||
"filename=${object}&" +
|
||||
"size=${size}&" +
|
||||
"mime_type=${mimeType}&" +
|
||||
"height=${imageInfo.height}&" +
|
||||
"width=${imageInfo.width}&" +
|
||||
"format=${imageInfo.format}&" +
|
||||
"origin_filename=${x:origin_filename}"
|
||||
callbackParam.CallbackBodyType = "application/x-www-form-urlencoded"
|
||||
callback_str, err := json.Marshal(callbackParam)
|
||||
if err != nil {
|
||||
fmt.Println("callback json err:", err)
|
||||
}
|
||||
callbackBase64 := base64.StdEncoding.EncodeToString(callback_str)
|
||||
|
||||
var policyToken PolicyToken
|
||||
policyToken.AccessKeyId = oc.AccessKeyId
|
||||
policyToken.Host = oc.Host
|
||||
policyToken.Expire = expire_end
|
||||
policyToken.Signature = string(signedStr)
|
||||
policyToken.Directory = uploadDir
|
||||
policyToken.Policy = string(debyte)
|
||||
policyToken.Callback = string(callbackBase64)
|
||||
response, err := json.Marshal(policyToken)
|
||||
if err != nil {
|
||||
fmt.Println("json err:", err)
|
||||
}
|
||||
return string(response)
|
||||
}
|
||||
|
||||
func (oc *Oss) Verify(r *http.Request) bool {
|
||||
|
||||
// Get PublicKey bytes
|
||||
bytePublicKey, err := getPublicKey(r)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get Authorization bytes : decode from Base64String
|
||||
byteAuthorization, err := getAuthorization(r)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get MD5 bytes from Newly Constructed Authrization String.
|
||||
byteMD5, err := getMD5FromNewAuthString(r)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// verifySignature and response to client
|
||||
if verifySignature(bytePublicKey, byteMD5, byteAuthorization) {
|
||||
// do something you want accoding to callback_body ...
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// getPublicKey : Get PublicKey bytes from Request.URL
|
||||
func getPublicKey(r *http.Request) ([]byte, error) {
|
||||
var bytePublicKey []byte
|
||||
// get PublicKey URL
|
||||
publicKeyURLBase64 := r.Header.Get("x-oss-pub-key-url")
|
||||
if publicKeyURLBase64 == "" {
|
||||
fmt.Println("GetPublicKey from Request header failed : No x-oss-pub-key-url field. ")
|
||||
return bytePublicKey, errors.New("no x-oss-pub-key-url field in Request header ")
|
||||
}
|
||||
publicKeyURL, _ := base64.StdEncoding.DecodeString(publicKeyURLBase64)
|
||||
// fmt.Printf("publicKeyURL={%s}\n", publicKeyURL)
|
||||
// get PublicKey Content from URL
|
||||
responsePublicKeyURL, err := http.Get(string(publicKeyURL))
|
||||
if err != nil {
|
||||
fmt.Printf("Get PublicKey Content from URL failed : %s \n", err.Error())
|
||||
return bytePublicKey, err
|
||||
}
|
||||
bytePublicKey, err = ioutil.ReadAll(responsePublicKeyURL.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("Read PublicKey Content from URL failed : %s \n", err.Error())
|
||||
return bytePublicKey, err
|
||||
}
|
||||
defer responsePublicKeyURL.Body.Close()
|
||||
// fmt.Printf("publicKey={%s}\n", bytePublicKey)
|
||||
return bytePublicKey, nil
|
||||
}
|
||||
|
||||
// getAuthorization : decode from Base64String
|
||||
func getAuthorization(r *http.Request) ([]byte, error) {
|
||||
var byteAuthorization []byte
|
||||
// Get Authorization bytes : decode from Base64String
|
||||
strAuthorizationBase64 := r.Header.Get("authorization")
|
||||
if strAuthorizationBase64 == "" {
|
||||
fmt.Println("Failed to get authorization field from request header. ")
|
||||
return byteAuthorization, errors.New("no authorization field in Request header")
|
||||
}
|
||||
byteAuthorization, _ = base64.StdEncoding.DecodeString(strAuthorizationBase64)
|
||||
return byteAuthorization, nil
|
||||
}
|
||||
|
||||
// getMD5FromNewAuthString : Get MD5 bytes from Newly Constructed Authrization String.
|
||||
func getMD5FromNewAuthString(r *http.Request) ([]byte, error) {
|
||||
var byteMD5 []byte
|
||||
// Construct the New Auth String from URI+Query+Body
|
||||
bodyContent, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyContent))
|
||||
if err != nil {
|
||||
fmt.Printf("Read Request Body failed : %s \n", err.Error())
|
||||
return byteMD5, err
|
||||
}
|
||||
strCallbackBody := string(bodyContent)
|
||||
// fmt.Printf("r.URL.RawPath={%s}, r.URL.Query()={%s}, strCallbackBody={%s}\n", r.URL.RawPath, r.URL.Query(), strCallbackBody)
|
||||
strURLPathDecode, errUnescape := unescapePath(r.URL.Path, encodePathSegment) //url.PathUnescape(r.URL.Path) for Golang v1.8.2+
|
||||
if errUnescape != nil {
|
||||
fmt.Printf("url.PathUnescape failed : URL.Path=%s, error=%s \n", r.URL.Path, err.Error())
|
||||
return byteMD5, errUnescape
|
||||
}
|
||||
|
||||
// Generate New Auth String prepare for MD5
|
||||
strAuth := ""
|
||||
if r.URL.RawQuery == "" {
|
||||
strAuth = fmt.Sprintf("%s\n%s", strURLPathDecode, strCallbackBody)
|
||||
} else {
|
||||
strAuth = fmt.Sprintf("%s?%s\n%s", strURLPathDecode, r.URL.RawQuery, strCallbackBody)
|
||||
}
|
||||
// fmt.Printf("NewlyConstructedAuthString={%s}\n", strAuth)
|
||||
|
||||
// Generate MD5 from the New Auth String
|
||||
md5Ctx := md5.New()
|
||||
md5Ctx.Write([]byte(strAuth))
|
||||
byteMD5 = md5Ctx.Sum(nil)
|
||||
|
||||
return byteMD5, nil
|
||||
}
|
||||
|
||||
/* VerifySignature
|
||||
* VerifySignature需要三个重要的数据信息来进行签名验证: 1>获取公钥PublicKey; 2>生成新的MD5鉴权串; 3>解码Request携带的鉴权串;
|
||||
* 1>获取公钥PublicKey : 从RequestHeader的"x-oss-pub-key-url"字段中获取 URL, 读取URL链接的包含的公钥内容, 进行解码解析, 将其作为rsa.VerifyPKCS1v15的入参。
|
||||
* 2>生成新的MD5鉴权串 : 把Request中的url中的path部分进行urldecode, 加上url的query部分, 再加上body, 组合之后进行MD5编码, 得到MD5鉴权字节串。
|
||||
* 3>解码Request携带的鉴权串 : 获取RequestHeader的"authorization"字段, 对其进行Base64解码,作为签名验证的鉴权对比串。
|
||||
* rsa.VerifyPKCS1v15进行签名验证,返回验证结果。
|
||||
* */
|
||||
func verifySignature(bytePublicKey []byte, byteMd5 []byte, authorization []byte) bool {
|
||||
pubBlock, _ := pem.Decode(bytePublicKey)
|
||||
if pubBlock == nil {
|
||||
fmt.Printf("Failed to parse PEM block containing the public key")
|
||||
return false
|
||||
}
|
||||
pubInterface, err := x509.ParsePKIXPublicKey(pubBlock.Bytes)
|
||||
if (pubInterface == nil) || (err != nil) {
|
||||
fmt.Printf("x509.ParsePKIXPublicKey(publicKey) failed : %s \n", err.Error())
|
||||
return false
|
||||
}
|
||||
pub := pubInterface.(*rsa.PublicKey)
|
||||
|
||||
errorVerifyPKCS1v15 := rsa.VerifyPKCS1v15(pub, crypto.MD5, byteMd5, authorization)
|
||||
if errorVerifyPKCS1v15 != nil {
|
||||
fmt.Printf("\nSignature Verification is Failed : %s \n", errorVerifyPKCS1v15.Error())
|
||||
//printByteArray(byteMd5, "AuthMd5(fromNewAuthString)")
|
||||
//printByteArray(bytePublicKey, "PublicKeyBase64")
|
||||
//printByteArray(authorization, "AuthorizationFromRequest")
|
||||
return false
|
||||
}
|
||||
|
||||
fmt.Printf("\nSignature Verification is Successful. \n")
|
||||
return true
|
||||
}
|
||||
|
||||
func printByteArray(byteArrary []byte, arrName string) {
|
||||
fmt.Printf("++++++++ printByteArray : ArrayName=%s, ArrayLength=%d \n", arrName, len(byteArrary))
|
||||
for i := 0; i < len(byteArrary); i++ {
|
||||
fmt.Printf("%02x", byteArrary[i])
|
||||
}
|
||||
fmt.Printf("\n-------- printByteArray : End . \n")
|
||||
}
|
||||
|
||||
type EscapeError string
|
||||
|
||||
func (e EscapeError) Error() string {
|
||||
return "invalid URL escape " + strconv.Quote(string(e))
|
||||
}
|
||||
|
||||
type InvalidHostError string
|
||||
|
||||
func (e InvalidHostError) Error() string {
|
||||
return "invalid character " + strconv.Quote(string(e)) + " in host name"
|
||||
}
|
||||
|
||||
type encoding int
|
||||
|
||||
const (
|
||||
encodePath encoding = 1 + iota
|
||||
encodePathSegment
|
||||
encodeHost
|
||||
encodeZone
|
||||
encodeUserPassword
|
||||
encodeQueryComponent
|
||||
encodeFragment
|
||||
)
|
||||
|
||||
// unescapePath : unescapes a string; the mode specifies, which section of the URL string is being unescaped.
|
||||
func unescapePath(s string, mode encoding) (string, error) {
|
||||
// Count %, check that they're well-formed.
|
||||
mode = encodePathSegment
|
||||
n := 0
|
||||
hasPlus := false
|
||||
for i := 0; i < len(s); {
|
||||
switch s[i] {
|
||||
case '%':
|
||||
n++
|
||||
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
|
||||
s = s[i:]
|
||||
if len(s) > 3 {
|
||||
s = s[:3]
|
||||
}
|
||||
return "", EscapeError(s)
|
||||
}
|
||||
// Per https://tools.ietf.org/html/rfc3986#page-21
|
||||
// in the host component %-encoding can only be used
|
||||
// for non-ASCII bytes.
|
||||
// But https://tools.ietf.org/html/rfc6874#section-2
|
||||
// introduces %25 being allowed to escape a percent sign
|
||||
// in IPv6 scoped-address literals. Yay.
|
||||
if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
|
||||
return "", EscapeError(s[i : i+3])
|
||||
}
|
||||
if mode == encodeZone {
|
||||
// RFC 6874 says basically "anything goes" for zone identifiers
|
||||
// and that even non-ASCII can be redundantly escaped,
|
||||
// but it seems prudent to restrict %-escaped bytes here to those
|
||||
// that are valid host name bytes in their unescaped form.
|
||||
// That is, you can use escaping in the zone identifier but not
|
||||
// to introduce bytes you couldn't just write directly.
|
||||
// But Windows puts spaces here! Yay.
|
||||
v := unhex(s[i+1])<<4 | unhex(s[i+2])
|
||||
if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) {
|
||||
return "", EscapeError(s[i : i+3])
|
||||
}
|
||||
}
|
||||
i += 3
|
||||
case '+':
|
||||
hasPlus = mode == encodeQueryComponent
|
||||
i++
|
||||
default:
|
||||
if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
|
||||
return "", InvalidHostError(s[i : i+1])
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if n == 0 && !hasPlus {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
t := make([]byte, len(s)-2*n)
|
||||
j := 0
|
||||
for i := 0; i < len(s); {
|
||||
switch s[i] {
|
||||
case '%':
|
||||
t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
|
||||
j++
|
||||
i += 3
|
||||
case '+':
|
||||
if mode == encodeQueryComponent {
|
||||
t[j] = ' '
|
||||
} else {
|
||||
t[j] = '+'
|
||||
}
|
||||
j++
|
||||
i++
|
||||
default:
|
||||
t[j] = s[i]
|
||||
j++
|
||||
i++
|
||||
}
|
||||
}
|
||||
return string(t), nil
|
||||
}
|
||||
|
||||
// Please be informed that for now shouldEscape does not check all
|
||||
// reserved characters correctly. See golang.org/issue/5684.
|
||||
func shouldEscape(c byte, mode encoding) bool {
|
||||
// §2.3 Unreserved characters (alphanum)
|
||||
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
if mode == encodeHost || mode == encodeZone {
|
||||
// §3.2.2 Host allows
|
||||
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
|
||||
// as part of reg-name.
|
||||
// We add : because we include :port as part of host.
|
||||
// We add [ ] because we include [ipv6]:port as part of host.
|
||||
// We add < > because they're the only characters left that
|
||||
// we could possibly allow, and Parse will reject them if we
|
||||
// escape them (because hosts can't use %-encoding for
|
||||
// ASCII bytes).
|
||||
switch c {
|
||||
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '-', '_', '.', '~': // §2.3 Unreserved characters (mark)
|
||||
return false
|
||||
|
||||
case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
|
||||
// Different sections of the URL allow a few of
|
||||
// the reserved characters to appear unescaped.
|
||||
switch mode {
|
||||
case encodePath: // §3.3
|
||||
// The RFC allows : @ & = + $ but saves / ; , for assigning
|
||||
// meaning to individual path segments. This package
|
||||
// only manipulates the path as a whole, so we allow those
|
||||
// last three as well. That leaves only ? to escape.
|
||||
return c == '?'
|
||||
|
||||
case encodePathSegment: // §3.3
|
||||
// The RFC allows : @ & = + $ but saves / ; , for assigning
|
||||
// meaning to individual path segments.
|
||||
return c == '/' || c == ';' || c == ',' || c == '?'
|
||||
|
||||
case encodeUserPassword: // §3.2.1
|
||||
// The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
|
||||
// userinfo, so we must escape only '@', '/', and '?'.
|
||||
// The parsing of userinfo treats ':' as special so we must escape
|
||||
// that too.
|
||||
return c == '@' || c == '/' || c == '?' || c == ':'
|
||||
|
||||
case encodeQueryComponent: // §3.4
|
||||
// The RFC reserves (so we must escape) everything.
|
||||
return true
|
||||
|
||||
case encodeFragment: // §4.1
|
||||
// The RFC text is silent but the grammar allows
|
||||
// everything, so escape nothing.
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Everything else must be escaped.
|
||||
return true
|
||||
}
|
||||
|
||||
func ishex(c byte) bool {
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
return true
|
||||
case 'a' <= c && c <= 'f':
|
||||
return true
|
||||
case 'A' <= c && c <= 'F':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func unhex(c byte) byte {
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
return c - '0'
|
||||
case 'a' <= c && c <= 'f':
|
||||
return c - 'a' + 10
|
||||
case 'A' <= c && c <= 'F':
|
||||
return c - 'A' + 10
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user