Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,35 @@ class Ethereum (
private val dappMetadata: DappMetadata,
sdkOptions: SDKOptions? = null,
private val logger: Logger = DefaultLogger,
private val communicationClientModule: CommunicationClientModule = CommunicationClientModule(context)
): EthereumEventCallback {
private var connectRequestSent = false

private val communicationClient: CommunicationClient? by lazy {
communicationClientModule.provideCommunicationClient(this)
}

private val storage = communicationClientModule.provideKeyStorage()
private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.Main)

private val communicationClientModule: CommunicationClientModuleInterface = CommunicationClientModule(context),
private val infuraProvider: InfuraProvider? = sdkOptions?.let {
if (it.infuraAPIKey.isNotEmpty()) {
InfuraProvider(it.infuraAPIKey)
} else {
null
}
}
): EthereumEventCallback {
private var connectRequestSent = false

val communicationClient: CommunicationClient? by lazy {
communicationClientModule.provideCommunicationClient(this)
}

private val storage = communicationClientModule.provideKeyStorage()

// Ethereum LiveData
private val _ethereumState = MutableLiveData(EthereumState("", "", ""))
private val currentEthereumState: EthereumState
get() = checkNotNull(ethereumState.value)
val ethereumState: LiveData<EthereumState> get() = _ethereumState

// Expose plain variables for developers who prefer not using observing live data via ethereumState
val chainId: String
get() = if (currentEthereumState.chainId.isEmpty()) { currentEthereumState.chainId } else { cachedChainId }
val selectedAddress: String
get() = if (currentEthereumState.selectedAddress.isEmpty()) { currentEthereumState.selectedAddress } else { cachedAccount }

private var cachedChainId = ""
private var cachedAccount = ""

var selectedAddress: String = ethereumState.value?.selectedAddress.takeIf { !it.isNullOrEmpty() } ?: cachedAccount
var chainId: String = ethereumState.value?.selectedAddress.takeIf { !it.isNullOrEmpty() } ?: cachedChainId

// Toggle SDK tracking
var enableDebug: Boolean = true
set(value) {
Expand Down Expand Up @@ -106,6 +101,7 @@ class Ethereum (
)
)
if (account.isNotEmpty()) {
selectedAddress = account
storage.putValue(account, key = SessionManager.SESSION_ACCOUNT_KEY, SessionManager.SESSION_CONFIG_FILE)
}
}
Expand All @@ -119,6 +115,7 @@ class Ethereum (
)
)
if (newChainId.isNotEmpty()) {
chainId = newChainId
storage.putValue(newChainId, key = SessionManager.SESSION_CHAIN_ID_KEY, SessionManager.SESSION_CONFIG_FILE)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.metamask.androidsdk

import org.json.JSONObject

class InfuraProvider(private val infuraAPIKey: String, private val logger: Logger = DefaultLogger) {
open class InfuraProvider(private val infuraAPIKey: String, private val logger: Logger = DefaultLogger) {
val rpcUrls: Map<String, String> = mapOf(
// ###### Ethereum ######
// Mainnet
Expand Down Expand Up @@ -70,7 +70,7 @@ class InfuraProvider(private val infuraAPIKey: String, private val logger: Logge
return !rpcUrls[chainId].isNullOrEmpty()
}

fun makeRequest(request: RpcRequest, chainId: String, dappMetadata: DappMetadata, callback: ((Result) -> Unit)?) {
open fun makeRequest(request: RpcRequest, chainId: String, dappMetadata: DappMetadata, callback: ((Result) -> Unit)?) {
val httpClient = HttpClient()

val devicePlatformInfo = DeviceInfo.platformDescription
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import java.lang.reflect.Type

class SessionManager(
private val store: SecureStorage,
private var sessionDuration: Long = 30 * 24 * 3600, // 30 days default
var sessionDuration: Long = 30 * 24 * 3600, // 30 days default
private val logger: Logger = DefaultLogger
) {
var sessionId: String = ""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@

package io.metamask.androidsdk

import android.content.ComponentName
import android.content.Context
import android.os.IBinder
import io.metamask.androidsdk.KeyExchangeMessageType.*
import io.metamask.nativesdk.IMessegeService
import io.metamask.androidsdk.Event.*
import io.metamask.androidsdk.MockInfuraProvider
import org.json.JSONObject
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.mockito.MockitoAnnotations
import org.mockito.kotlin.mock
import org.mockito.Mockito

import org.robolectric.RobolectricTestRunner
import org.junit.runner.RunWith
import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.verify
import org.mockito.Mockito.`when`
import org.mockito.kotlin.any

@RunWith(RobolectricTestRunner::class)
class EthereumTests {

private lateinit var context: Context

private lateinit var mockEthereumEventCallback: MockEthereumEventCallback
private lateinit var logger: Logger
private lateinit var keyExchange: KeyExchange
private lateinit var sessionManager: SessionManager
private lateinit var mockClientServiceConnection: MockClientServiceConnection
private lateinit var mockClientMessageServiceCallback: MockClientMessageServiceCallback
private lateinit var mockCrypto: MockCrypto
private lateinit var mockTracker: MockTracker

private lateinit var mockCommunicationClientModule: MockCommunicationClientModule
private lateinit var ethereum: Ethereum
private lateinit var mockStorage: MockKeyStorage
private lateinit var communicationClient: CommunicationClient
private lateinit var mockInfuraProvider: MockInfuraProvider

@Before
fun setup() {
MockitoAnnotations.openMocks(this)
context = mock()

logger = TestLogger
mockEthereumEventCallback = MockEthereumEventCallback()
mockClientServiceConnection = MockClientServiceConnection()
mockClientMessageServiceCallback = MockClientMessageServiceCallback()

mockCrypto = MockCrypto()
mockTracker = MockTracker()
keyExchange = KeyExchange(mockCrypto, logger)
mockStorage = MockKeyStorage()
sessionManager = SessionManager(mockStorage)
mockInfuraProvider = MockInfuraProvider(SDKOptions(infuraAPIKey = "01234567").infuraAPIKey, logger)

mockCommunicationClientModule = MockCommunicationClientModule(
context,
mockStorage,
sessionManager,
keyExchange,
mockClientServiceConnection,
mockClientMessageServiceCallback,
mockTracker,
logger
)
ethereum = Ethereum(
context,
DappMetadata("testApp","http://www.testapp.com", iconUrl = null, base64Icon = null),
sdkOptions = SDKOptions(infuraAPIKey = "01234567"),
logger,
mockCommunicationClientModule,
mockInfuraProvider
)
communicationClient = ethereum.communicationClient!!
}

@Test
fun testUpdateAccount() = runBlocking {
val testAccount = "0x12345"
ethereum.updateAccount(testAccount)
delay(10)
assertEquals(testAccount, ethereum.selectedAddress)
assertEquals(testAccount, mockStorage.getValue(SessionManager.SESSION_ACCOUNT_KEY, SessionManager.SESSION_CONFIG_FILE))
}

@Test
fun testUpdateChainId() = runBlocking {
val testChainId = "0x1"
ethereum.updateChainId(testChainId)
delay(10)
assertEquals(testChainId, ethereum.chainId)
assertEquals(testChainId, mockStorage.getValue(SessionManager.SESSION_CHAIN_ID_KEY, SessionManager.SESSION_CONFIG_FILE))
}

@Test
fun testEthereumConnect() {
val testResult: Result = Result.Success.Item("0x123456")
var callbackResult: Result? = null

prepareCommunicationClient()

ethereum.connect { result ->
callbackResult = result
}

val requestId = findRequestIdForAccountRequest(EthereumMethod.ETH_REQUEST_ACCOUNTS)
communicationClient.completeRequest(requestId, testResult)

assertTrue(callbackResult is Result.Success)
assertEquals(callbackResult, testResult)

val trackedEvent = mockTracker.trackedEvent
assertEquals(trackedEvent, SDK_CONNECTION_AUTHORIZED)
assertNotNull(mockTracker.trackedEventParams)
assertEquals(SDK_CONNECTION_AUTHORIZED.value, mockTracker.trackedEventParams?.get("event"))
}

@Test
fun testEthereumConnectError() {
val errorCode = 4001
val errorMessage = "User rejected request"
val testResult: Result = Result.Error(RequestError(errorCode, errorMessage))
var callbackResult: Result? = null

prepareCommunicationClient()

// Assuming the connect method modifies the internal state and captures results
ethereum.connect { result ->
callbackResult = result
}

// Simulate the completion of the request made by connect
val requestId = findRequestIdForAccountRequest(EthereumMethod.ETH_REQUEST_ACCOUNTS)
communicationClient.completeRequest(requestId, testResult)

assertTrue(callbackResult is Result.Error)
assertEquals(testResult, callbackResult)

val trackedEvent = mockTracker.trackedEvent
assertEquals(SDK_CONNECTION_REJECTED, trackedEvent)
assertNotNull(mockTracker.trackedEventParams)
assertEquals(SDK_CONNECTION_REJECTED.value, mockTracker.trackedEventParams?.get("event"))
}

@Test
fun testConnectWith() {
val params: MutableMap<String, Any> = mutableMapOf(
"from" to "0x12345",
"to" to "0x98765",
"amount" to "0x1"
)

val transactionRequest = EthereumRequest(
method = EthereumMethod.ETH_SEND_TRANSACTION.value,
params = listOf(params)
)

var callbackResult: Result? = null

prepareCommunicationClient()

ethereum.connectWith(transactionRequest) { result ->
callbackResult = result
}

val requestId = findRequestIdForAccountRequest(EthereumMethod.METAMASK_CONNECT_WITH)
val testResult: Result = Result.Success.Item("0x24680")
communicationClient.completeRequest(requestId, testResult)

assertTrue(callbackResult is Result.Success)
assertEquals(callbackResult, testResult)

val trackedEvent = mockTracker.trackedEvent
assertEquals(trackedEvent, SDK_CONNECTION_AUTHORIZED)
assertNotNull(mockTracker.trackedEventParams)
assertEquals(SDK_CONNECTION_AUTHORIZED.value, mockTracker.trackedEventParams?.get("event"))
}

@Test
fun testConnectSign() {
val messageToSign = "Sign this message"
var callbackResult: Result? = null

prepareCommunicationClient()

ethereum.connectSign(messageToSign) { result ->
callbackResult = result
}

val requestId = findRequestIdForAccountRequest(EthereumMethod.METAMASK_CONNECT_SIGN)
val testResult: Result = Result.Success.Item("0xdhjdheeeeeew")
communicationClient.completeRequest(requestId, testResult)

// Assertions to verify the correct handling
assertTrue(callbackResult is Result.Success)
assertEquals(callbackResult, testResult)

val trackedEvent = mockTracker.trackedEvent
assertEquals(Event.SDK_CONNECTION_AUTHORIZED, trackedEvent)
assertNotNull(mockTracker.trackedEventParams)
assertEquals(SDK_CONNECTION_AUTHORIZED.value, mockTracker.trackedEventParams?.get("event"))
}

@Test
fun testUpdateSessionDuration() {
val newDuration = 10 * 24 * 3600L // 10 days
runBlocking {
ethereum.updateSessionDuration(newDuration)
delay(10)

// Ensure session duration in session manager is updated
assertEquals(newDuration, sessionManager.sessionDuration)
}
}

@Test
fun testClearSession() {
mockStorage.putValue("0x1", key = SessionManager.SESSION_CHAIN_ID_KEY, SessionManager.SESSION_CONFIG_FILE)
assertFalse(mockStorage.isClear())
ethereum.clearSession()
assertTrue(mockStorage.isClear())
}

@Test
fun testMetaMaskOpenedForUserInteraction() {
val request = EthereumRequest(method = EthereumMethod.ETH_SEND_TRANSACTION.value, params = listOf("to: '0x456', value: '1000'"))
ethereum.connect {}

ethereum.sendRequest(request)

// Assuming `openMetaMask` does something observable like firing an intent
verify(context, atLeastOnce()).startActivity(any())
}

@Test
fun testReadOnlyRequestUsingInfura() {
val request = EthereumRequest(method = EthereumMethod.ETH_GET_BALANCE.value, params = listOf("0x123", "latest"))
val mockResponse = "{\"balance\": \"1000\"}"
mockInfuraProvider.mockResponse = mockResponse

ethereum.connect {}

ethereum.sendRequest(request) { result ->
assertTrue(result is Result.Success)
when (result) {
is Result.Success.Item -> {
assertEquals(mockResponse, result.value)
}
else -> {
fail("Result should be success")
}
}
}
}

private fun findRequestIdForAccountRequest(method: EthereumMethod): String {
return communicationClient.submittedRequests.entries.find {
it.value.request.method == method.value
}?.key ?: throw IllegalStateException("No account request found")
}

private fun prepareCommunicationClient() {
val mockBinder = Mockito.mock(IBinder::class.java)
val mockMessageService = Mockito.mock(IMessegeService::class.java)
`when`(IMessegeService.Stub.asInterface(mockBinder)).thenReturn(mockMessageService)

// mock service connection
mockClientServiceConnection.onServiceConnected(ComponentName(context, "Service"), mockBinder)

// mock receiver
val receiverKeyExchange = KeyExchange(MockCrypto(), logger)

// exchange public keys
val receiverKeyExchangeMessage = KeyExchangeMessage(KEY_HANDSHAKE_ACK.name, receiverKeyExchange.publicKey)
val senderKeyExchangeMessage = KeyExchangeMessage(KEY_HANDSHAKE_ACK.name, keyExchange.publicKey)

keyExchange.nextKeyExchangeMessage(receiverKeyExchangeMessage)
receiverKeyExchange.nextKeyExchangeMessage(senderKeyExchangeMessage)

// mock key exchange complete
keyExchange.complete()

// mock receiving ready message
val readyMessage = JSONObject().apply {
put(MessageType.TYPE.value, MessageType.READY.value)
}.toString()
val encryptedReadyMessage = receiverKeyExchange.encrypt(readyMessage)

// simulate MetaMask Ready
communicationClient.handleMessage(encryptedReadyMessage)
}
}
Loading