package db

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"os"
	"strconv"
	"sync"
	"testing"
)

// newTestLDB opens an LDBDatabase in a fresh temporary directory and returns
// it together with a cleanup function that closes the database and removes
// the directory. It panics if the directory or the database cannot be
// created.
func newTestLDB() (*LDBDatabase, func()) {
	dirname, err := ioutil.TempDir(os.TempDir(), "db_test_")
	if err != nil {
		panic("failed to create test file: " + err.Error())
	}
	db, err := NewLDBDatabase(dirname, 0, 0)
	if err != nil {
		// Don't leak the temporary directory when the database fails to open.
		os.RemoveAll(dirname)
		panic("failed to create test database: " + err.Error())
	}
	return db, func() {
		db.Close()
		os.RemoveAll(dirname)
	}
}

// test_values are the keys (and, in some phases, the values) exercised by
// testPutGet. They include the empty string and embedded NUL bytes.
var test_values = []string{"", "a", "1251", "\x00123\x00"}

func TestLDB_PutGet(t *testing.T) {
	db, remove := newTestLDB()
	defer remove()
	testPutGet(db, t)
}

func TestMemoryDB_PutGet(t *testing.T) {
	testPutGet(NewMemDatabase(), t)
}

// testPutGet runs a sequence of single-goroutine checks against db:
//   - keys stored with a nil value read back as empty,
//   - reading a key that was never written returns an error,
//   - values round-trip and can be overwritten,
//   - mutating a slice returned by Get does not change the stored value,
//   - deleted keys can no longer be read.
//
// It calls t.Parallel, so the calling test runs alongside other parallel tests.
func testPutGet(db Database, t *testing.T) {
	t.Parallel()

	// Keys stored with a nil value must read back as empty.
	for _, k := range test_values {
		err := db.Put([]byte(k), nil)
		if err != nil {
			t.Fatalf("put failed: %v", err)
		}
	}
	for _, k := range test_values {
		data, err := db.Get([]byte(k))
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
		if len(data) != 0 {
			t.Fatalf("get returned wrong result, got %q expected nil", string(data))
		}
	}

	// A key that was never written must report an error.
	_, err := db.Get([]byte("non-exist-key"))
	if err == nil {
		t.Fatalf("expect to return a not found error")
	}

	// Each key stores itself as its value.
	for _, v := range test_values {
		err := db.Put([]byte(v), []byte(v))
		if err != nil {
			t.Fatalf("put failed: %v", err)
		}
	}
	for _, v := range test_values {
		data, err := db.Get([]byte(v))
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
		if !bytes.Equal(data, []byte(v)) {
			t.Fatalf("get returned wrong result, got %q expected %q", string(data), v)
		}
	}

	// Overwriting replaces the previous value.
	for _, v := range test_values {
		err := db.Put([]byte(v), []byte("?"))
		if err != nil {
			t.Fatalf("put override failed: %v", err)
		}
	}
	for _, v := range test_values {
		data, err := db.Get([]byte(v))
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
		if !bytes.Equal(data, []byte("?")) {
			t.Fatalf("get returned wrong result, got %q expected ?", string(data))
		}
	}

	// Mutating the slice returned by Get must not change the stored value.
	// Every value is "?" at this point, so orig[0] is in range.
	for _, v := range test_values {
		orig, err := db.Get([]byte(v))
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
		orig[0] = byte(0xff)
		data, err := db.Get([]byte(v))
		if err != nil {
			t.Fatalf("get failed: %v", err)
		}
		if !bytes.Equal(data, []byte("?")) {
			t.Fatalf("get returned wrong result, got %q expected ?", string(data))
		}
	}

	// Deleted keys must no longer be readable.
	for _, v := range test_values {
		err := db.Delete([]byte(v))
		if err != nil {
			t.Fatalf("delete %q failed: %v", v, err)
		}
	}
	for _, v := range test_values {
		_, err := db.Get([]byte(v))
		if err == nil {
			t.Fatalf("got deleted value %q", v)
		}
	}
}

func TestLDB_ParallelPutGet(t *testing.T) {
	db, remove := newTestLDB()
	defer remove()
	testParallelPutGet(db, t)
}

func TestMemoryDB_ParallelPutGet(t *testing.T) {
	testParallelPutGet(NewMemDatabase(), t)
}

// testParallelPutGet writes, reads back, deletes and re-reads n keys. Each
// phase performs one operation per key, each in its own goroutine; phases run
// one after another and each waits for all of its goroutines to finish.
func testParallelPutGet(db Database, t *testing.T) {
	const n = 8

	// runPhase calls op concurrently for the keys "0" .. strconv.Itoa(n-1),
	// waits for every call to return, and fails the test if any call
	// reported an error. Errors are handed back over a channel so that the
	// test is failed from the test goroutine (t.FailNow must not be called
	// from other goroutines) rather than by panicking, which would abort the
	// entire test binary.
	runPhase := func(op func(key string) error) {
		t.Helper()
		var pending sync.WaitGroup
		errc := make(chan error, n) // buffered so no worker blocks on send
		pending.Add(n)
		for i := 0; i < n; i++ {
			go func(key string) {
				defer pending.Done()
				if err := op(key); err != nil {
					errc <- err
				}
			}(strconv.Itoa(i))
		}
		pending.Wait()
		close(errc)

		failed := false
		for err := range errc {
			t.Error(err)
			failed = true
		}
		if failed {
			t.FailNow()
		}
	}

	runPhase(func(key string) error {
		if err := db.Put([]byte(key), []byte("v"+key)); err != nil {
			return fmt.Errorf("put failed: %v", err)
		}
		return nil
	})
	runPhase(func(key string) error {
		data, err := db.Get([]byte(key))
		if err != nil {
			return fmt.Errorf("get failed: %v", err)
		}
		if want := []byte("v" + key); !bytes.Equal(data, want) {
			return fmt.Errorf("get failed, got %q expected %q", data, want)
		}
		return nil
	})
	runPhase(func(key string) error {
		if err := db.Delete([]byte(key)); err != nil {
			return fmt.Errorf("delete failed: %v", err)
		}
		return nil
	})
	runPhase(func(key string) error {
		if _, err := db.Get([]byte(key)); err == nil {
			return fmt.Errorf("get succeeded")
		}
		return nil
	})
}