Updating tokens
OAuth2 relies on access tokens that expire after a period of time.
This can cause problems when you need to call a service asynchronously many times over a long period: the access token may expire in the meantime.
This tutorial shows one way to deal with this problem.
Solution
- A database table to store the tokens.
- A worker that refreshes the tokens (using the refresh token) when the access token has 1 hour or less left before it expires.
- A microservice that provides the most up-to-date token.
Using Python and PostgreSQL
This example uses Python as the programming language and PostgreSQL as the database.
GitHub - Code
This GitHub folder contains all the files mentioned in the tutorial.
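Install Dependencies
The dependencies used in this tutorial are:
- pysqlx-engine: SQL engine
- httpx: HTTP client
- pydantic: data validation (imported directly by the code; it is normally installed together with pysqlx-engine, but install it explicitly if it is missing)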
$ pip install httpx pysqlx-engine
Requirement already satisfied: httpx in ./.pyenv/versions/3.10.1/lib/python3.10/site-packages (1.4.36)
Collecting pysqlx-engine
Downloading pysqlx_engine-2.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.0/3.0 MB 25.9 MB/s eta
...
...
File Structure
The tutorial uses the following files:
- database.py
- schema.py
- consumer.py
- worker.py
Code
database.py
This file contains the TokenTable class, which represents a row of the quartile_tokens table, and the TokenDB class, which manages the database connection.
The TokenDB class will be used to save the tokens in the database and read them back.
from datetime import datetime
from pysqlx_engine import PySQLXEngineSync, BaseRow
from pydantic import Field, Json
from uuid import UUID
from os import environ
# Primary-key value of the single quartile_tokens row that holds the tokens.
# Every reader and writer uses this same id, so treat it as a fixed constant
# and do not change it.
QUARTILE_ID = "febdb7d2-e95d-11ed-9255-65636df51340"
class TokenTable(BaseRow):
    """
    One row of the quartile_tokens table (see TokenDB.create_table).

    The default ``id`` is built as ``UUID(QUARTILE_ID)`` rather than the raw
    string. Pydantic does not validate default values unless the model config
    enables it, so a plain string default would stay a ``str``. That would
    disagree with the ``UUID`` annotation in ``.dict()`` output.
    """

    # primary key; the worker always writes the single QUARTILE_ID row
    id: UUID = UUID(QUARTILE_ID)
    access_token: str
    # access-token expiry as a POSIX timestamp in seconds
    # (worker.py stores tokens.authorization.expires_at.timestamp())
    access_token_expires: int
    refresh_token: str
    # True while worker.py is refreshing the tokens; consumer.py waits meanwhile
    updating: bool
    # full AuthToken payload: worker.py writes tokens.json() and the Json type
    # parses that JSON string back into Python objects
    json_data: Json
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class TokenDB:
    """
    Data-access helper for the quartile_tokens table.

    All queries go through one PySQLXEngineSync connection, which is opened
    when the object is constructed.
    """

    def __init__(self):
        """
        Connect to the database given by the DATABASE_URL environment variable.

        The URL is read with environ["DATABASE_URL"], so a KeyError is raised
        when the variable is not set.
        example: postgresql://USER:PASSWORD@HOST:PORT/DATABASE
        """
        self.uri = environ["DATABASE_URL"]
        self.db = PySQLXEngineSync(uri=self.uri)
        # the connection is opened right away, during construction
        self.db.connect()

    def create_table(self):
        """
        Create the quartile_tokens table if it does not exist.

        Because of IF NOT EXISTS this does nothing when the table is already
        there, so it can be called on every worker run.
        """
        sql = """
CREATE TABLE IF NOT EXISTS quartile_tokens (
id UUID NOT NULL,
access_token TEXT,
refresh_token TEXT,
access_token_expires INTEGER,
updating BOOLEAN,
json_data JSON,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (id)
);
"""
        self.db.execute(sql=sql)

    def select_by_id(self, id: UUID) -> TokenTable | None:
        """
        Fetch the row with the given id, mapped onto TokenTable.

        The result of query_first is returned as-is. Callers (worker.main,
        consumer.get_access_token) treat a None result as "no row stored yet".
        """
        sql = "SELECT * FROM quartile_tokens WHERE id = :id;"
        return self.db.query_first(
            sql=sql,
            parameters={"id": id},
            model=TokenTable,
        )

    def insert(self, data: TokenTable) -> TokenTable:
        """
        Insert a new row built from data and return it as read back from the table.

        Every column, including created_at and updated_at, is taken from
        data.dict(). A second insert with the same id fails on the
        PRIMARY KEY (id) constraint.
        """
        sql = """
INSERT INTO quartile_tokens (
id,
access_token,
refresh_token,
access_token_expires,
updating,
json_data,
created_at,
updated_at
) VALUES (
:id,
:access_token,
:refresh_token,
:access_token_expires,
:updating,
:json_data,
:created_at,
:updated_at
);
"""
        params = data.dict()
        self.db.execute(sql=sql, parameters=params)
        # read the row back so the caller gets what the database actually stored
        return self.select_by_id(id=data.id)

    def update(self, data: TokenTable) -> TokenTable | None:
        """
        Overwrite the tokens, updating flag, json_data and updated_at of the
        row whose id is data.id, then return that row as read back.

        created_at is not touched by the SQL. If no row has that id, nothing
        is updated and the read-back returns None.
        """
        sql = """
UPDATE quartile_tokens SET
access_token = :access_token,
refresh_token = :refresh_token,
access_token_expires = :access_token_expires,
updating = :updating,
json_data = :json_data,
updated_at = :updated_at
WHERE id = :id;
"""
        # data.dict() also carries created_at, which this UPDATE does not
        # reference -- NOTE(review): presumably pysqlx ignores unused named
        # parameters; confirm if the driver changes
        params = data.dict()
        self.db.execute(sql=sql, parameters=params)
        return self.select_by_id(id=data.id)

    def update_status(self, id: UUID, updating: bool) -> TokenTable | None:
        """
        Set the updating flag on one row (and stamp updated_at with datetime.now()).

        consumer.get_access_token only hands out the token while updating is
        False, so worker.py sets it to True while it refreshes the tokens.
        The UPDATE is unconditional (no compare-and-set on the current value),
        so on its own it does not stop two workers from refreshing at once.
        Returns the row as read back, or None if no row has that id.
        """
        sql = """
UPDATE quartile_tokens SET
updating = :updating,
updated_at = :updated_at
WHERE id = :id;
"""
        params = {"id": id, "updating": updating, "updated_at": datetime.now()}
        self.db.execute(sql=sql, parameters=params)
        return self.select_by_id(id=id)
schema.py
This file contains the models that represent the tokens.
The AuthToken class will be used to parse the response from the Quartile API.
from datetime import datetime
from pydantic import BaseModel as _BaseModel, Field
class BaseModel(_BaseModel):
    """
    Shared pydantic base class for the Quartile API payload models below.
    """

    class Config:
        # allow building models from objects' attributes (Model.from_orm)
        orm_mode = True
        # silently drop response keys that are not declared as fields.
        # pydantic v1 spells this option ``extra``; the previous
        # ``ignore_extra`` is not a v1 config key and was simply ignored.
        # "ignore" is also pydantic's default, so behavior is unchanged.
        extra = "ignore"
        # accept both the field name (e.g. expires_in) and its alias (expiresIn)
        allow_population_by_field_name = True
        # permit field types pydantic has no validator for
        # (none of the models in this file currently need it)
        arbitrary_types_allowed = True
class Authorization(BaseModel):
    """
    Access-token section of the Quartile auth response (the ``authorization``
    key of AuthToken). The camelCase API keys are mapped through aliases.
    """

    # the access token itself; worker.py stores it in TokenTable.access_token
    token: str
    # token type as returned by the API (not interpreted by this code)
    type: str
    # lifetime reported by the API -- NOTE(review): unit not visible here, presumably seconds; confirm
    expires_in: int = Field(..., alias="expiresIn")
    # absolute expiry; worker.py stores expires_at.timestamp() as access_token_expires
    expires_at: datetime = Field(..., alias="expiresAt")
    # "notBefore" value from the API -- NOTE(review): semantics/unit not visible here; confirm
    not_before: int = Field(..., alias="notBefore")
    # free-text note returned by the API
    note: str
class Refresh(BaseModel):
    """
    Refresh-token section of the Quartile auth response (the ``refresh`` key
    of AuthToken).
    """

    # refresh token; worker.py sends it to the /refresh endpoint
    token: str
    # lifetime reported by the API -- NOTE(review): presumably seconds; confirm
    expires_in: int = Field(..., alias="expiresIn")
    # absolute expiry of the refresh token (not used by worker.py's refresh check)
    expires_at: datetime = Field(..., alias="expiresAt")
    # free-text note returned by the API
    note: str
class AuthToken(BaseModel):
    """
    Full body of the Quartile /login and /refresh responses.

    worker.py parses the HTTP responses into this model and stores its JSON
    in TokenTable.json_data; consumer.py parses that JSON back into it.
    """

    # access token and its expiry
    authorization: Authorization
    # refresh token used to obtain the next token pair
    refresh: Refresh
consumer.py
This file contains the code that consumes the token. Note that the token is only returned when the updating field is False!
This flag is changed by worker.py while it updates the token.
import logging
from time import sleep
from database import QUARTILE_ID, TokenDB
from schema import AuthToken
# Log at INFO level so the polling progress messages below are printed.
logging.basicConfig(level=logging.INFO)
def get_access_token(poll_interval: float = 5) -> AuthToken:
    """
    Wait until a usable access token is stored in the database and return it.

    The row identified by QUARTILE_ID is read with TokenDB.select_by_id.
    The token is returned as soon as that row exists and its ``updating``
    flag is False; worker.py sets the flag to True while it refreshes the
    tokens. Otherwise the function logs, sleeps ``poll_interval`` seconds
    and tries again, with no limit on the number of attempts.

    Args:
        poll_interval: seconds to wait between polls (default 5).

    Returns:
        AuthToken: the token payload stored in the row's json_data column.
    """
    # Open a single connection and reuse it for every poll, instead of
    # constructing (and connecting) a new TokenDB on each iteration.
    db = TokenDB()
    while True:
        logging.info("getting token...")
        row = db.select_by_id(id=QUARTILE_ID)
        # the row may not exist yet (worker never ran) or may be mid-refresh
        if row is not None and row.updating is False:
            logging.info("token found.")
            return AuthToken.parse_obj(row.json_data)
        logging.info(
            f"token not found or updating, waiting {poll_interval} seconds..."
        )
        sleep(poll_interval)
if __name__ == "__main__":
    # Block until a token is available, then print it as indented JSON.
    logging.info(get_access_token().json(indent=4))
worker.py
The worker.py script is responsible for updating the token when it is about to expire.
The access token expires in 12 hours, so the worker updates the token when there is 1 hour or less left before it expires.
For example, you can run the worker with a scheduler every 30 minutes, or run it in a separate process.
Some cloud providers offer simple ways to run this kind of routine using a timer trigger, for example:
import logging
from datetime import datetime, timedelta, timezone
from os import environ

import httpx

from database import QUARTILE_ID, TokenDB, TokenTable
from schema import AuthToken
# Log at INFO level so the worker's progress messages are printed.
logging.basicConfig(level=logging.INFO)
# Base URI of the Quartile authentication API; "/login" and "/refresh" are
# appended to it by login() and refresh_tokens().
BASE_URI = "https://api.quartile.com/auth/v2"
# Subscription key sent in the "Subscription-Key" header of every request.
# Set it in the QUARTILE_SUBSCRIPTION_KEY environment variable; you can find
# it in the Developer Portal. environ[...] raises KeyError at import time if
# the variable is missing (the same applies to the two settings below).
SUBSCRIPTION_KEY = environ["QUARTILE_SUBSCRIPTION_KEY"]
# Username (e-mail) used to log in to the Portal, from QUARTILE_YOUR_USERNAME.
USERNAME = environ["QUARTILE_YOUR_USERNAME"]
# Portal password, from QUARTILE_YOUR_PASSWORD.
PASSWORD = environ["QUARTILE_YOUR_PASSWORD"]
# HTTP - API
def login() -> AuthToken:
    """
    Log in to the Quartile API with the portal username and password.

    Sends POST {BASE_URI}/login with the subscription key header and a
    10-second timeout.

    Returns:
        AuthToken: the parsed response body.

    Raises:
        AssertionError: if the response status is not 201, with the response
            text as message. It is raised explicitly (not with ``assert``),
            so the check still runs under ``python -O``.
    """
    uri = f"{BASE_URI}/login"
    body = {"username": USERNAME, "password": PASSWORD}
    headers = {"Subscription-Key": SUBSCRIPTION_KEY}
    resp = httpx.post(url=uri, json=body, headers=headers, timeout=10)
    logging.info(f"login... return status code: {resp.status_code}")
    if resp.status_code != 201:
        raise AssertionError(resp.text)
    return AuthToken.parse_obj(resp.json())
def refresh_tokens(refresh_token: str) -> AuthToken:
    """
    Exchange a refresh token for a new token pair.

    Sends POST {BASE_URI}/refresh with the subscription key header and a
    10-second timeout.

    Args:
        refresh_token: the refresh token stored in the database.

    Returns:
        AuthToken: the parsed response body.

    Raises:
        AssertionError: if the response status is not 201, with the response
            text as message. update_token catches this exception to fall back
            to login(). It is therefore raised explicitly rather than with
            ``assert``, which ``python -O`` would strip.
    """
    uri = f"{BASE_URI}/refresh"
    body = {"token": refresh_token}
    headers = {"Subscription-Key": SUBSCRIPTION_KEY}
    resp = httpx.post(url=uri, json=body, headers=headers, timeout=10)
    logging.info(f"refreshing tokens... return status code: {resp.status_code}")
    if resp.status_code != 201:
        raise AssertionError(resp.text)
    return AuthToken.parse_obj(resp.json())
# DATABASE - POSTGRES
def insert_token(db: TokenDB):
    """
    Log in to the Quartile API and store the first token set.

    Used when no QUARTILE_ID row exists yet. The new row is written with
    ``updating=False`` so consumers can read it immediately.
    """
    fresh = login()
    access, refresh = fresh.authorization, fresh.refresh
    first_row = TokenTable(
        id=QUARTILE_ID,
        access_token=access.token,
        access_token_expires=access.expires_at.timestamp(),
        refresh_token=refresh.token,
        updating=False,
        json_data=fresh.json(),
    )
    db.insert(data=first_row)
    logging.info("tokens inserted")
def update_token(db: TokenDB, row: TokenTable):
    """
    Replace the stored tokens with a fresh set and clear the updating flag.

    A refresh with ``row.refresh_token`` is tried first. If that call fails
    its status check (AssertionError), a full login is done instead. The new
    tokens are then written to the QUARTILE_ID row with ``updating=False``.
    """
    try:
        new_tokens = refresh_tokens(refresh_token=row.refresh_token)
    except AssertionError:
        # the refresh token was rejected (most likely expired):
        # fall back to a full login for a brand-new token pair
        logging.info("refresh token expired, getting new tokens...")
        new_tokens = login()
    logging.info("tokens refreshed")

    access = new_tokens.authorization
    db.update(
        data=TokenTable(
            id=QUARTILE_ID,
            access_token=access.token,
            access_token_expires=access.expires_at.timestamp(),
            refresh_token=new_tokens.refresh.token,
            updating=False,
            json_data=new_tokens.json(),
        )
    )
    logging.info("tokens updated")
# MAIN
def main():
    """
    Run one pass of the token worker.

    - Creates the quartile_tokens table if needed.
    - If no QUARTILE_ID row exists yet, logs in, inserts the tokens and returns.
    - Otherwise, if the stored access token expires within the next hour (or
      has already expired), marks the row as ``updating`` and refreshes it.

    Each call does at most one login/refresh, so run it periodically
    (e.g. from a scheduler).

    Raises:
        Whatever insert_token / update_token raise (e.g. AssertionError on a
        non-201 API response). If the refresh fails, the ``updating`` flag is
        reset to False before the error is re-raised, so consumers are not
        left waiting forever.
    """
    logging.info("starting worker...")
    db = TokenDB()
    # create the quartile_tokens table if it does not exist yet;
    # you can remove this line if you already have the table
    db.create_table()
    row = db.select_by_id(id=QUARTILE_ID)
    # no row yet: log in to get a first set of tokens and store it
    if row is None:
        logging.info("inserting tokens...")
        insert_token(db=db)
        return  # exit the function
    # access_token_expires is a POSIX timestamp (expires_at.timestamp()).
    # Refresh when the token expires within the next hour: expires <= now + 1h.
    # Use an aware UTC "now"; datetime.utcnow().timestamp() would interpret
    # the naive UTC value as local time and be off by the machine's UTC offset.
    refresh_deadline = int(
        (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp()
    )
    if row.access_token_expires <= refresh_deadline:
        logging.info("updating tokens...")
        # mark the row as updating so the consumer waits for the new token
        db.update_status(id=QUARTILE_ID, updating=True)
        try:
            update_token(db=db, row=row)
        except Exception:
            # update_token clears the flag only on success. Clear it on
            # failure too, otherwise consumers would poll forever; the stored
            # token may still be valid for up to an hour in the meantime.
            logging.exception("token update failed, clearing updating flag")
            db.update_status(id=QUARTILE_ID, updating=False)
            raise
if __name__ == "__main__":
    # one worker pass per invocation of the script
    main()