diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 000000000..c351d9289 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,196 @@ +package db_test + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "strconv" + "sync" + "testing" + + "github.com/simple-rules/harmony-benchmark/db" +) + +func newTestLDB() (*db.LDBDatabase, func()) { + dirname, err := ioutil.TempDir(os.TempDir(), "db_test_") + if err != nil { + panic("failed to create test file: " + err.Error()) + } + db, err := db.NewLDBDatabase(dirname, 0, 0) + if err != nil { + panic("failed to create test database: " + err.Error()) + } + + return db, func() { + db.Close() + os.RemoveAll(dirname) + } +} + +var test_values = []string{"", "a", "1251", "\x00123\x00"} + +func TestLDB_PutGet(t *testing.T) { + db, remove := newTestLDB() + defer remove() + testPutGet(db, t) +} + +func TestMemoryDB_PutGet(t *testing.T) { + testPutGet(db.NewMemDatabase(), t) +} + +func testPutGet(db db.Database, t *testing.T) { + t.Parallel() + + for _, k := range test_values { + err := db.Put([]byte(k), nil) + if err != nil { + t.Fatalf("put failed: %v", err) + } + } + + for _, k := range test_values { + data, err := db.Get([]byte(k)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if len(data) != 0 { + t.Fatalf("get returned wrong result, got %q expected nil", string(data)) + } + } + + _, err := db.Get([]byte("non-exist-key")) + if err == nil { + t.Fatalf("expect to return a not found error") + } + + for _, v := range test_values { + err := db.Put([]byte(v), []byte(v)) + if err != nil { + t.Fatalf("put failed: %v", err) + } + } + + for _, v := range test_values { + data, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if !bytes.Equal(data, []byte(v)) { + t.Fatalf("get returned wrong result, got %q expected %q", string(data), v) + } + } + + for _, v := range test_values { + err := db.Put([]byte(v), []byte("?")) + if err != nil { + t.Fatalf("put override failed: %v", err) + } + } + + for _, v := range test_values { + data, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if !bytes.Equal(data, []byte("?")) { + t.Fatalf("get returned wrong result, got %q expected ?", string(data)) + } + } + + for _, v := range test_values { + orig, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + orig[0] = byte(0xff) + data, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if !bytes.Equal(data, []byte("?")) { + t.Fatalf("get returned wrong result, got %q expected ?", string(data)) + } + } + + for _, v := range test_values { + err := db.Delete([]byte(v)) + if err != nil { + t.Fatalf("delete %q failed: %v", v, err) + } + } + + for _, v := range test_values { + _, err := db.Get([]byte(v)) + if err == nil { + t.Fatalf("got deleted value %q", v) + } + } +} + +func TestLDB_ParallelPutGet(t *testing.T) { + db, remove := newTestLDB() + defer remove() + testParallelPutGet(db, t) +} + +func TestMemoryDB_ParallelPutGet(t *testing.T) { + testParallelPutGet(db.NewMemDatabase(), t) +} + +func testParallelPutGet(db db.Database, t *testing.T) { + const n = 8 + var pending sync.WaitGroup + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + err := db.Put([]byte(key), []byte("v"+key)) + if err != nil { + panic("put failed: " + err.Error()) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + data, err := db.Get([]byte(key)) + if err != nil { + panic("get failed: " + err.Error()) + } + if !bytes.Equal(data, []byte("v"+key)) { + panic(fmt.Sprintf("get failed, got %q expected %q", []byte(data), []byte("v"+key))) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + err := db.Delete([]byte(key)) + if err != nil { + panic("delete failed: " + err.Error()) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + _, err := db.Get([]byte(key)) + if err == nil { + panic("get succeeded") + } + }(strconv.Itoa(i)) + } + pending.Wait() +}