Skip to content
Merged
195 changes: 179 additions & 16 deletions coinbaseadvanced/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,36 @@
"""

from typing import List
from enum import Enum
from datetime import datetime, timedelta
from cryptography.hazmat.primitives import serialization

import jwt
import hmac
import hashlib
import time
import json
import requests
from coinbaseadvanced.models.common import UnixTime

from coinbaseadvanced.models.fees import TransactionsSummary
from coinbaseadvanced.models.products import ProductsPage, Product, CandlesPage,\
from coinbaseadvanced.models.products import BidAsksPage, ProductBook, ProductsPage, Product, CandlesPage,\
TradesPage, ProductType, Granularity, GRANULARITY_MAP_IN_MINUTES
from coinbaseadvanced.models.accounts import AccountsPage, Account
from coinbaseadvanced.models.orders import OrderPlacementSource, OrdersPage, Order,\
OrderBatchCancellation, FillsPage, Side, StopDirection, OrderType


class AuthSchema(Enum):
"""
Enum representing authetication schema:
https://docs.cloud.coinbase.com/advanced-trade-api/docs/auth#authentication-schemes
"""

CLOUD_API_TRADING_KEYS = "CLOUD_API_TRADING_KEYS"
LEGACY_API_KEYS = "LEGACY_API_KEYS"


class CoinbaseAdvancedTradeAPIClient(object):
"""
API Client for Coinbase Advanced Trade endpoints.
Expand All @@ -28,11 +42,36 @@ def __init__(self,
api_key: str,
secret_key: str,
base_url: str = 'https://api.coinbase.com',
timeout: int = 10) -> None:
timeout: int = 10,
auth_schema: AuthSchema = AuthSchema.LEGACY_API_KEYS
) -> None:
self._base_url = base_url
self._host = base_url[8:]
self._api_key = api_key
self._secret_key = secret_key
self.timeout = timeout
self._auth_schema = auth_schema

@staticmethod
def from_legacy_api_keys(api_key: str,
secret_key: str):
"""
Factory method for legacy auth schema.
API keys for this schema are generated via: https://www.coinbase.com/settings/api
"""
return CoinbaseAdvancedTradeAPIClient(api_key=api_key, secret_key=secret_key)

@staticmethod
def from_cloud_api_keys(api_key_name: str,
private_key: str):
"""
Factory method for cloud auth schema (recommended by Coinbase).
API keys for this schema are generated via: https://cloud.coinbase.com/access/api
"""
return CoinbaseAdvancedTradeAPIClient(api_key=api_key_name, secret_key=private_key,
auth_schema=AuthSchema.CLOUD_API_TRADING_KEYS)

# Accounts #

# Accounts #

Expand All @@ -57,11 +96,12 @@ def list_accounts(self, limit: int = 49, cursor: str = None) -> AccountsPage:
method = "GET"
query_params = '?limit='+str(limit)

headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

if cursor is not None:
query_params = query_params + '&cursor='+cursor

headers = self._build_request_headers(method, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
timeout=self.timeout)
Expand Down Expand Up @@ -105,7 +145,8 @@ def get_account(self, account_id: str) -> Account:
request_path = f"/api/v3/brokerage/accounts/{account_id}"
method = "GET"

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path, headers=headers, timeout=self.timeout)

Expand Down Expand Up @@ -275,7 +316,8 @@ def create_order(self, client_order_id: str,
'order_configuration': order_configuration
}

headers = self._build_request_headers(method, request_path, json.dumps(payload))
headers = self._build_request_headers(method, request_path, json.dumps(payload)) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)
response = requests.post(self._base_url+request_path,
json=payload, headers=headers,
timeout=self.timeout)
Expand All @@ -300,7 +342,8 @@ def cancel_orders(self, order_ids: list) -> OrderBatchCancellation:
'order_ids': order_ids,
}

headers = self._build_request_headers(method, request_path, json.dumps(payload))
headers = self._build_request_headers(method, request_path, json.dumps(payload)) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)
response = requests.post(self._base_url+request_path,
json=payload,
headers=headers,
Expand Down Expand Up @@ -398,7 +441,8 @@ def list_orders(
if order_placement_source is not None:
query_params = self._next_param(query_params) + 'order_placement_source=' + order_placement_source.value

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand Down Expand Up @@ -495,7 +539,8 @@ def list_fills(self, order_id: str = None, product_id: str = None, start_date: d
if cursor is not None:
query_params = self._next_param(query_params) + 'cursor=' + cursor

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand Down Expand Up @@ -530,7 +575,8 @@ def get_order(self, order_id: str) -> Order:
request_path = f"/api/v3/brokerage/orders/historical/{order_id}"
method = "GET"

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path, headers=headers, timeout=self.timeout)

Expand Down Expand Up @@ -568,7 +614,8 @@ def list_products(self,
if product_type is not None:
query_params = self._next_param(query_params) + 'product_type=' + product_type.value

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand All @@ -590,7 +637,8 @@ def get_product(self, product_id: str) -> Product:
request_path = f"/api/v3/brokerage/products/{product_id}"
method = "GET"

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path, headers=headers, timeout=self.timeout)

Expand Down Expand Up @@ -624,7 +672,8 @@ def get_product_candles(
query_params = self._next_param(query_params) + 'end=' + str(int(end_date.timestamp()))
query_params = self._next_param(query_params) + 'granularity=' + granularity.value

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand Down Expand Up @@ -697,7 +746,8 @@ def get_market_trades(

query_params = self._next_param(query_params) + 'limit=' + str(limit)

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand All @@ -706,6 +756,61 @@ def get_market_trades(
trades_page = TradesPage.from_response(response)
return trades_page

def get_product_book(self, product_id: str, limit: int = None) -> ProductBook:
"""
https://docs.cloud.coinbase.com/advanced-trade-api/reference/retailbrokerageapi_getproductbook

Get a list of bids/asks for a single product.
The amount of detail shown can be customized with the limit parameter.

Args:
- product_id: The trading pair.
- limit: A pagination limit.
"""

request_path = f"/api/v3/brokerage/product_book"
method = "GET"

query_params = ''
if product_id is not None:
query_params = self._next_param(query_params) + 'product_id='+product_id

if limit is not None:
query_params = self._next_param(query_params) + 'limit='+str(limit)

headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params, headers=headers, timeout=self.timeout)

bid_asks_page = ProductBook.from_response(response)
return bid_asks_page

def get_best_bid_ask(self, product_ids: List[str] = None) -> BidAsksPage:
"""
https://docs.cloud.coinbase.com/advanced-trade-api/reference/retailbrokerageapi_getbestbidask

Get the best bid/ask for all products. A subset of all products can be returned instead by using the product_ids input.

Args:
- product_ids: Subset of all products to be returned instead.
"""

request_path = f"/api/v3/brokerage/best_bid_ask"
method = "GET"

query_params = ''
if product_ids is not None:
query_params = self._next_param(query_params) + 'product_ids='+'&product_ids='.join(product_ids)

headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params, headers=headers, timeout=self.timeout)

bid_asks_page = BidAsksPage.from_response(response)
return bid_asks_page

# Fees #

def get_transactions_summary(self,
Expand Down Expand Up @@ -739,7 +844,8 @@ def get_transactions_summary(self,
if product_type is not None:
query_params = self._next_param(query_params) + 'product_type='+product_type.value

headers = self._build_request_headers(method, request_path)
headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path+query_params,
headers=headers,
Expand All @@ -748,7 +854,59 @@ def get_transactions_summary(self,
page = TransactionsSummary.from_response(response)
return page

# Helpers #
# Common #

def get_unix_time(self) -> UnixTime:
"""
https://docs.cloud.coinbase.com/advanced-trade-api/reference/retailbrokerageapi_getunixtime

Get the current time from the Coinbase Advanced API.

"""

request_path = f"/api/v3/brokerage/time"
method = "GET"

headers = self._build_request_headers(method, request_path) if self._is_legacy_auth(
) else self._build_request_headers_for_cloud(method, self._host, request_path)

response = requests.get(self._base_url+request_path, headers=headers, timeout=self.timeout)

time = UnixTime.from_response(response)
return time

# Helpers Methods #

## Cloud Auth ##

def _build_request_headers_for_cloud(self, method, host, request_path):
uri = f"{method} {host}{request_path}"
jwt_token = self._build_jwt("retail_rest_api_proxy", uri)

return {
"Authorization": f"Bearer {jwt_token}",
}

def _build_jwt(self, service, uri):
private_key_bytes = self._secret_key.encode('utf-8')
private_key = serialization.load_pem_private_key(private_key_bytes, password=None)
jwt_payload = {
'sub': self._api_key,
'iss': "coinbase-cloud",
'nbf': int(time.time()),
'exp': int(time.time()) + 60,
'aud': [service],
'uri': uri,
}
jwt_token = jwt.encode(
jwt_payload,
private_key,
algorithm='ES256',
headers={'kid': self._api_key, 'nonce': str(int(time.time()))},
)
return jwt_token

## Legacy Auth ##

def _build_request_headers(self, method, request_path, body=''):
timestamp = str(int(time.time()))
Expand All @@ -771,5 +929,10 @@ def _create_signature(self, message):

return signature

def _is_legacy_auth(self) -> bool:
return self._auth_schema == AuthSchema.LEGACY_API_KEYS

## Others ##

def _next_param(self, query_params: str) -> str:
return query_params + ('?' if query_params == '' else '&')
12 changes: 6 additions & 6 deletions coinbaseadvanced/models/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
Object models for account related endpoints args and response.
"""

import json
from uuid import UUID
from datetime import datetime
from typing import List
import requests

from coinbaseadvanced.models.common import BaseModel
from coinbaseadvanced.models.error import CoinbaseAdvancedTradeAPIError


class AvailableBalance:
class AvailableBalance(BaseModel):
"""
Available Balance object.
"""
Expand All @@ -26,7 +26,7 @@ def __init__(self, value: str, currency: str, **kwargs) -> None:
self.kwargs = kwargs


class Account:
class Account(BaseModel):
"""
Object representing an account.
"""
Expand Down Expand Up @@ -73,12 +73,12 @@ def from_response(cls, response: requests.Response) -> 'Account':
if not response.ok:
raise CoinbaseAdvancedTradeAPIError.not_ok_response(response)

result = json.loads(response.text)
result = response.json()
account_dict = result['account']
return cls(**account_dict)


class AccountsPage:
class AccountsPage(BaseModel):
"""
Page of accounts.
"""
Expand Down Expand Up @@ -114,7 +114,7 @@ def from_response(cls, response: requests.Response) -> 'AccountsPage':
if not response.ok:
raise CoinbaseAdvancedTradeAPIError.not_ok_response(response)

result = json.loads(response.text)
result = response.json()
return cls(**result)

def __iter__(self):
Expand Down
Loading