diff --git a/api/proto/identity/identity.go b/api/proto/identity/identity.go index 856ce90f0..4895c5e63 100644 --- a/api/proto/identity/identity.go +++ b/api/proto/identity/identity.go @@ -46,7 +46,7 @@ func GetIdentityMessageType(message []byte) (MessageType, error) { if len(message) < 1 { return 0, errors.New("failed to get identity message type: no data available") } - return MessageType(message[2]), nil + return MessageType(message[0]), nil } // GetIdentityMessagePayload message payload from the identity message diff --git a/api/proto/identity/identity_test.go b/api/proto/identity/identity_test.go index 1e00df15f..d0d3c1819 100644 --- a/api/proto/identity/identity_test.go +++ b/api/proto/identity/identity_test.go @@ -3,11 +3,14 @@ package identity import ( "strings" "testing" + + "github.com/harmony-one/harmony/api/proto" ) func TestRegisterIdentityMessage(t *testing.T) { registerIdentityMessage := ConstructIdentityMessage(Register, []byte("registerIdentityMessage")) - messageType, err := GetIdentityMessageType(registerIdentityMessage) + msgPayload, err := proto.GetMessagePayload(registerIdentityMessage) + messageType, err := GetIdentityMessageType(msgPayload) if err != nil { t.Errorf("Error thrown in geting message type") } @@ -18,7 +21,8 @@ func TestRegisterIdentityMessage(t *testing.T) { func TestAcknowledgeIdentityMessage(t *testing.T) { registerAcknowledgeMessage := ConstructIdentityMessage(Acknowledge, []byte("acknowledgeIdentityMsgPayload")) - messageType, err := GetIdentityMessageType(registerAcknowledgeMessage) + msgPayload, err := proto.GetMessagePayload(registerAcknowledgeMessage) + messageType, err := GetIdentityMessageType(msgPayload) if err != nil { t.Errorf("Error thrown in geting message type") } @@ -29,6 +33,11 @@ func TestAcknowledgeIdentityMessage(t *testing.T) { func TestInvalidIdentityMessage(t *testing.T) { registerInvalidMessage := ConstructIdentityMessage(3, []byte("acknowledgeIdentityMsgPayload")) + registerInvalidMessagePayload, err := GetIdentityMessagePayload(registerInvalidMessage) + if err != nil { + t.Errorf("Error thrown in geting message type from invalid message") + } + _ = registerInvalidMessagePayload messageType, err := GetIdentityMessageType(registerInvalidMessage) if err != nil { t.Errorf("Error thrown in geting message type from invalid message")