Axum By Example (JWT)

Justin Bender

🦀 One of the best ways to understand code is to read more of it.

I would like to take a second to welcome you to this article. I hope you're having a great day. Be safe and have fun!

axum examples

Let's start with JSON Web Tokens

axum JWT example

The dependencies are going to be different on every project. So we will start with the Cargo.toml to understand what will be compiled into this project.

Cargo.toml:

[package]
name = "example-jwt"
version = "0.1.0"
edition = "2021"
publish = false
 
[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
jsonwebtoken = "8.0"
once_cell = "1.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
  • Lazy: a type provided by the once_cell crate, from its sync module. It creates a static variable that is initialized lazily. What does that mean? The value is computed the first time it is used, and every later access reuses that value. Initialization happens exactly once, even if several threads access it concurrently.

A very simple configuration file. If you don't know about JSON Web Tokens, I recommend reading more about them.

It's one of the ways we can validate that someone is allowed to access an API. It isn't the most secure method in the world, but it is a widely used one. You will want extra protections as well. For example, you could keep track of the IP addresses a token is used from, or watch for other triggers that suggest someone malicious has access to the token.

At the very minimum, if someone steals your JWT, they could use it to log in and become you on the website. How? The token is how we validate a user. If they can present a valid token, they become you.
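
To make that risk concrete, here is a minimal sketch (standard library only, not part of the axum example) that reads the payload out of a JWT without knowing the secret. The payload is only base64url-encoded, not encrypted; the signature in the third segment is what the server checks. The helper names base64url_decode and jwt_payload_json are made up for illustration, and the token is the sample one from the curl instructions in main.rs.

```rust
/// Sample token taken from the curl instructions in main.rs.
const EXAMPLE_TOKEN: &str = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM";

/// Decodes base64url text (the alphabet JWTs use). Returns None on an invalid character.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
  let mut out = Vec::with_capacity(input.len() * 3 / 4);
  let mut buffer: u32 = 0; // bits collected so far
  let mut bits: u32 = 0; // how many bits in `buffer` are still unread
  for byte in input.bytes() {
    let value = match byte {
      b'A'..=b'Z' => byte - b'A',
      b'a'..=b'z' => byte - b'a' + 26,
      b'0'..=b'9' => byte - b'0' + 52,
      b'-' => 62,
      b'_' => 63,
      b'=' => continue, // optional padding carries no data
      _ => return None,
    };
    buffer = (buffer << 6) | u32::from(value);
    bits += 6;
    if bits >= 8 {
      bits -= 8;
      out.push((buffer >> bits) as u8);
      buffer &= (1 << bits) - 1; // keep only the unread bits
    }
  }
  Some(out)
}

/// Splits `header.payload.signature` and returns the payload as JSON text.
/// Note that nothing here verifies the signature.
fn jwt_payload_json(token: &str) -> Option<String> {
  let mut segments = token.split('.');
  let _header = segments.next()?;
  let payload = segments.next()?;
  let _signature = segments.next()?;
  if segments.next().is_some() {
    return None; // a JWT has exactly three segments
  }
  String::from_utf8(base64url_decode(payload)?).ok()
}
```

Calling jwt_payload_json(EXAMPLE_TOKEN) returns the JSON claims (sub, company and exp) in clear text, which is why a JWT should never carry anything secret and why a stolen token is enough to impersonate its owner.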

main.rs:

//! Example JWT authorization/authentication
//!
//! Run with
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run -p example-jwt
//! ```
 
use axum::{
  async_trait,
  extract::FromRequestParts,
  http::{request::Parts, StatusCode},
  response::{IntoResponse, Response},
  routing::{get, post},
  Json, RequestPartsExt, Router,
};
use axum_extra::{
  headers::{authorization::Bearer, Authorization},
  TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header,
Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
// Quick instructions
//
// - get an authorization token:
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -d '{"client_id":"foo","client_secret":"bar"}' \
//     http://localhost:3000/authorize
//
// - visit the protected area using the authorized token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
//     http://localhost:3000/protected
//
// - try to visit the protected area using an invalid token
//
// curl -s \
//     -w '\n' \
//     -H 'Content-Type: application/json' \
//     -H 'Authorization: Bearer blahblahblah' \
//     http://localhost:3000/protected
 
static KEYS: Lazy<Keys> = Lazy::new(|| {
  let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
  Keys::new(secret.as_bytes())
});
 
#[tokio::main]
async fn main() {
  tracing_subscriber::registry()
    .with(
      tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "example_jwt=debug".into()),
    )
    .with(tracing_subscriber::fmt::layer())
    .init();
 
  let app = Router::new()
    .route("/protected", get(protected))
    .route("/authorize", post(authorize));
 
  let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
    .await
    .unwrap();
  tracing::debug!("listening on {}", listener.local_addr().unwrap());
  axum::serve(listener, app).await.unwrap();
}
 
async fn protected(claims: Claims) -> Result<String, AuthError> {
  // Send the protected data to user
  Ok(format!(
    "Welcome to the protected area :)\nYour data:\n{}",
    claims
  ))
}
 
async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
  // Check if the user sent the credentials
  if payload.client_id.is_empty() || payload.client_secret.is_empty() {
    return Err(AuthError::MissingCredential);
  }
  // Here you can check the user credentials from a database
  if payload.client_id != "foo" || payload.client_secret != "bar" {
    return Err(AuthError::WrongCredentials);
  }
  let claims = Claims {
    sub: "b@b.com".to_owned(),
    company: "ACME".to_owned(),
    // Mandatory expiry time as UTC timestamp
    exp: 2000000000, // May 2033
  };
  // Create the authorization token
  let token = encode(&Header::default(), &claims, &KEYS.encoding)
    .map_err(|_| AuthError::TokenCreation)?;
 
  // Send the authorized token
  Ok(Json(AuthBody::new(token)))
}
 
impl Display for Claims {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
  }
}
 
impl AuthBody {
  fn new(access_token: String) -> Self {
    Self {
      access_token,
      token_type: "Bearer".to_string(),
    }
  }
}
 
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
  S: Send + Sync,
{
  type Rejection = AuthError;
 
  async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
    // Extract the token from the authorization header
    let TypedHeader(Authorization(bearer)) = parts
      .extract::<TypedHeader<Authorization<Bearer>>>()
      .await
      .map_err(|_| AuthError::InvalidToken)?;
    // Decode the user data
    let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
      .map_err(|_| AuthError::InvalidToken)?;
 
    Ok(token_data.claims)
  }
}
 
impl IntoResponse for AuthError {
  fn into_response(self) -> Response {
    let (status, error_message) = match self {
      AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
      AuthError::MissingCredential => (StatusCode::BAD_REQUEST, "Missing credentials"),
      AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
      AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
    };
    let body = Json(json!({
      "error": error_message,
    }));
    (status, body).into_response()
  }
}
 
struct Keys {
  encoding: EncodingKey,
  decoding: DecodingKey,
}
 
impl Keys {
  fn new(secret: &[u8]) -> Self {
    Self {
      encoding: EncodingKey::from_secret(secret),
      decoding: DecodingKey::from_secret(secret),
    }
  }
}
 
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
  sub: String,
  company: String,
  exp: usize,
}
 
#[derive(Debug, Serialize)]
struct AuthBody {
  access_token: String,
  token_type: String,
}
 
#[derive(Debug, Deserialize)]
struct AuthPayload {
  client_id: String,
  client_secret: String,
}
 
#[derive(Debug)]
enum AuthError {
  WrongCredentials,
  MissingCredential,
  TokenCreation,
  InvalidToken,
}

There are pieces of this code that I did not understand at first, so I had to look a bit deeper. There is a lot to unpack, so the only place to start is from the top.

Importing all of the required crates

The goal of this project is to create a server that listens on two routes: authorize and protected. They do exactly what they sound like. The authorize route signs a client in and hands back a token. The protected route checks that the token proves a valid identity and returns the protected data.

We can accomplish this using axum, axum-extra, jsonwebtoken, once_cell, serde, serde_json, tokio, tracing and tracing_subscriber. I wasn't sure what once_cell and tracing_subscriber do, so they are described below the imports. The others you will have to read up on yourself. They are out of scope for this one.

use axum::{
  async_trait,
  extract::FromRequestParts,
  http::{request::Parts, StatusCode},
  response::{IntoResponse, Response},
  routing::{get, post},
  Json, RequestPartsExt, Router,
};
use axum_extra::{
  headers::{authorization::Bearer, Authorization},
  TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header,
Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
  • once_cell: Provides cells that can be written to only once. Its sync::Lazy type is used here to lazily initialize a static value. The value is computed on first use, and the initialization is thread-safe, so it is not recomputed even when several threads access it.

  • tracing_subscriber: Part of the larger tracing ecosystem, which provides a framework for structured, composable logging and diagnostics in Rust applications. It helps configure and set up tracing instrumentation in Rust projects. It provides various subscriber implementations, which are responsible for receiving tracing events and processing them in different ways, such as printing to the console, writing to a file, or sending to a distributed tracing system. We can customize the logging behavior and choose which subscribers to use in the application.
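
As a rough illustration of what lazy, once-only initialization means, the sketch below uses std::sync::OnceLock from the standard library as a stand-in for once_cell::sync::Lazy. That substitution is an assumption made so the snippet runs without extra crates; the example itself uses once_cell. The names SECRET, INIT_CALLS, secret and init_calls_after_concurrent_reads are hypothetical.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;
use std::thread;

// Counts how many times the initializer actually runs.
static INIT_CALLS: AtomicUsize = AtomicUsize::new(0);

// Standard-library stand-in for a `once_cell::sync::Lazy<String>` static.
static SECRET: OnceLock<String> = OnceLock::new();

/// Returns the shared value, running the initializer on first access only.
fn secret() -> &'static str {
  SECRET.get_or_init(|| {
    INIT_CALLS.fetch_add(1, Ordering::SeqCst);
    // Stand-in for reading JWT_SECRET from the environment.
    String::from("secret")
  })
}

/// Reads the value from several threads, then reports how often it was initialized.
fn init_calls_after_concurrent_reads(threads: usize) -> usize {
  let handles: Vec<_> = (0..threads)
    .map(|_| thread::spawn(|| secret().len()))
    .collect();
  for handle in handles {
    handle.join().expect("reader thread should not panic");
  }
  INIT_CALLS.load(Ordering::SeqCst)
}
```

However many threads call secret(), the initializer closure runs a single time, which is the same guarantee the example relies on for its KEYS static.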

The main function and secrets

In this first part we use Lazy from once_cell, which does what we described above. It creates the static KEYS the first time it is needed and avoids recomputing it afterwards. We read the secret with std::env::var("JWT_SECRET") and return Keys::new(secret.as_bytes()). If JWT_SECRET is not set, expect panics with the message "JWT_SECRET must be set".

We will use #[tokio::main] on our main function and allow it to be async fn main(). This will set us up for using tokio.

static KEYS: Lazy<Keys> = Lazy::new(|| {
  let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
  Keys::new(secret.as_bytes())
});
 
#[tokio::main]
async fn main() {
  tracing_subscriber::registry()
    .with(
      tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "example_jwt=debug".into()),
    )
    .with(tracing_subscriber::fmt::layer())
    .init();
 
  let app = Router::new()
    .route("/protected", get(protected))
    .route("/authorize", post(authorize));
 
  let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
    .await
    .unwrap();
  tracing::debug!("listening on {}", listener.local_addr().unwrap());
  axum::serve(listener, app).await.unwrap();
}
  • We create a new tracing registry to act as a central hub for collecting and managing tracing spans and events: tracing_subscriber::registry().with(...).with(...).init(). After init(), it is installed as the global default subscriber and captures and processes tracing events according to the specified filters and formatters.
  • Inside the first with() we have tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "example_jwt=debug".into()). This tries to build a filter from the RUST_LOG environment variable. If that variable is unset or invalid, it falls back to the directive "example_jwt=debug", which enables debug-level (and more severe) output for the example_jwt crate.
  • Inside the next with() we use tracing_subscriber::fmt::layer(), a layer that formats events as human-readable lines and writes them to standard output.
  • Create the app with Router::new() and connect the two routes we plan to build: protected (GET) and authorize (POST).
  • Create the listener with tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(). This starts a TCP listener using tokio.
  • tracing::debug!("listening on {}", listener.local_addr().unwrap()) just emits a debug event through the tracing setup. I won't go into it now.
  • Start axum and connect everything together with axum::serve(listener, app).await.unwrap(). We use the listener and the app we created above. That's our main function: the program sits and serves requests over TCP.
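
The fallback in the filter setup can be pictured with a small standard-library sketch. It is a simplified model, not how tracing_subscriber is implemented: filter_directive and split_directive are hypothetical helpers, and real filter strings can be richer (for example, several comma-separated directives).

```rust
/// Picks the filter string: the RUST_LOG value if one was provided,
/// otherwise the example's default directive.
fn filter_directive(rust_log: Option<&str>) -> String {
  rust_log.unwrap_or("example_jwt=debug").to_owned()
}

/// Splits a single `target=level` directive into its two halves.
fn split_directive(directive: &str) -> Option<(&str, &str)> {
  directive.split_once('=')
}
```

In a real program the Option would come from something like std::env::var("RUST_LOG").ok(). With nothing set, the directive reads as "for the target example_jwt, show events at debug level and above".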

Structs and Enums

Struct

  • Claims
  • AuthBody
  • AuthPayload
  • Keys

Enum

  • AuthError: the different kinds of errors that we expect to encounter in this project.
struct Keys {
  encoding: EncodingKey,
  decoding: DecodingKey,
}
 
impl Keys {
  fn new(secret: &[u8]) -> Self {
    Self {
      encoding: EncodingKey::from_secret(secret),
      decoding: DecodingKey::from_secret(secret),
    }
  }
}
 
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
  sub: String,
  company: String,
  exp: usize,
}
 
#[derive(Debug, Serialize)]
struct AuthBody {
  access_token: String,
  token_type: String,
}
 
#[derive(Debug, Deserialize)]
struct AuthPayload {
  client_id: String,
  client_secret: String,
}
 
#[derive(Debug)]
enum AuthError {
  WrongCredentials,
  MissingCredential,
  TokenCreation,
  InvalidToken,
}

I don't have much to say about these. They are basic Rust data structures, and reading them is fairly self-explanatory. The only thing to note is the Keys implementation that we snuck in there. It holds an EncodingKey and a DecodingKey, both built from the same secret with from_secret, so the same secret is used for signing tokens and for verifying them.

Let's create the functions for our paths

  • protected
  • authorize
async fn protected(claims: Claims) -> Result<String, AuthError> {
  // Send the protected data to user
  Ok(format!(
    "Welcome to the protected area :)\nYour data:\n{}",
    claims
  ))
}

For the protected function we take in the claims, and this is a part that confused me a tad. What is really going on here? There seems to be a lot of information missing from this section. How does it validate?

The validation lives in a trait implementation that we cover a bit later; it just isn't in this part. Because Claims is used as an extractor, the handler only runs if the user is authorized, and it receives the decoded data as claims.
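
The idea that the handler only ever sees already-validated claims can be modeled without axum. The sketch below is a deliberately simplified, synchronous model: Parts, FromParts and dispatch are hypothetical stand-ins for axum's request parts, its FromRequestParts trait and its routing, and the "validation" is a hard-coded comparison rather than real JWT decoding.

```rust
/// Hypothetical, stripped-down request: just one optional header value.
struct Parts {
  authorization: Option<String>,
}

/// Simplified stand-in for axum's FromRequestParts (synchronous, no state).
trait FromParts: Sized {
  fn from_parts(parts: &Parts) -> Result<Self, String>;
}

struct Claims {
  sub: String,
}

impl FromParts for Claims {
  fn from_parts(parts: &Parts) -> Result<Self, String> {
    // A real implementation would verify a JWT here; this toy version
    // only accepts one hard-coded header value.
    match parts.authorization.as_deref() {
      Some("Bearer good-token") => Ok(Claims { sub: "b@b.com".to_owned() }),
      _ => Err("Invalid token".to_owned()),
    }
  }
}

/// The handler never touches the request, only the extracted claims.
fn protected(claims: Claims) -> String {
  format!("Welcome to the protected area :)\nYour data:\nEmail: {}", claims.sub)
}

/// Conceptually what the framework does: extract first,
/// and call the handler only if extraction succeeded.
fn dispatch(parts: &Parts) -> Result<String, String> {
  let claims = Claims::from_parts(parts)?;
  Ok(protected(claims))
}
```

If extraction fails, dispatch returns the rejection and protected is never called, which is why the handler body contains no validation code of its own.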

async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
  // Check if the user sent the credentials
  if payload.client_id.is_empty() || payload.client_secret.is_empty() {
    return Err(AuthError::MissingCredential);
  }
  // Here you can check the user credentials from a database
  if payload.client_id != "foo" || payload.client_secret != "bar" {
    return Err(AuthError::WrongCredentials);
  }
  let claims = Claims {
    sub: "b@b.com".to_owned(),
    company: "ACME".to_owned(),
    // Mandatory expiry time as UTC timestamp
    exp: 2000000000, // May 2033
  };
  // Create the authorization token
  let token = encode(&Header::default(), &claims, &KEYS.encoding)
    .map_err(|_| AuthError::TokenCreation)?;
 
  // Send the authorized token
  Ok(Json(AuthBody::new(token)))
}

In our authorize function we take in the payload and make sure that it contains the correct client ID and secret. This example has no database connection; the credentials are hard-coded, and in a real application this is where you would check them against a database. If the credentials match, it builds a simple Claims object, encodes it into a newly created authorization token, and returns that token.
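
The example hard-codes exp to a fixed timestamp. In practice you would more likely derive it from the current time. The following is a minimal sketch under that assumption, using only the standard library; expiry_after is a hypothetical helper and is not part of the axum example.

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Returns an expiry timestamp (seconds since the Unix epoch) that lies
/// `ttl` after `now`. The result is a usize to match the `exp` field of Claims.
fn expiry_after(now: SystemTime, ttl: Duration) -> usize {
  let since_epoch = now
    .duration_since(UNIX_EPOCH)
    .expect("system clock should be after the Unix epoch");
  (since_epoch + ttl).as_secs() as usize
}
```

Inside authorize this could be used as exp: expiry_after(SystemTime::now(), Duration::from_secs(15 * 60)) for a token that is valid for fifteen minutes. Passing the current time in as a parameter keeps the function easy to test.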

What about the traits

Well, this is a confusing part of this program's setup, at least for me initially. Two traits need to be implemented: FromRequestParts for the Claims struct and IntoResponse for AuthError.

#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
  S: Send + Sync,
{
  type Rejection = AuthError;
 
  async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
    // Extract the token from the authorization header
    let TypedHeader(Authorization(bearer)) = parts
      .extract::<TypedHeader<Authorization<Bearer>>>()
      .await
      .map_err(|_| AuthError::InvalidToken)?;
    // Decode the user data
    let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
      .map_err(|_| AuthError::InvalidToken)?;
 
    Ok(token_data.claims)
  }
}

When a handler asks for a Claims value, axum builds it from the request by calling from_request_parts, which receives the request parts as a &mut Parts. This is where the bearer token is checked. First we extract the token from the Authorization header; if the header is missing or is not a bearer token, we return AuthError::InvalidToken. Then decode validates the token: with Validation::default() it verifies the signature against our secret and checks that the exp claim has not passed. Only if both steps succeed do we get the claims back.
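
The two steps, and the way both failures collapse into the same error, can be sketched with the standard library alone. This is an illustration under assumptions: bearer_token and extract_subject are hypothetical helpers, the header handling is simplified compared to TypedHeader, and the verify closure is a stand-in for jsonwebtoken::decode.

```rust
#[derive(Debug, PartialEq)]
enum AuthError {
  InvalidToken,
}

/// Returns the token part of a `Bearer <token>` header value.
fn bearer_token(header_value: &str) -> Option<&str> {
  let token = header_value.strip_prefix("Bearer ")?;
  if token.is_empty() {
    None
  } else {
    Some(token)
  }
}

/// Step 1: pull the token out of the header.
/// Step 2: hand it to `verify`, which returns the subject if the token is valid.
/// Either failure is reported as InvalidToken, mirroring the two map_err calls.
fn extract_subject<F>(header: Option<&str>, verify: F) -> Result<String, AuthError>
where
  F: Fn(&str) -> Option<String>,
{
  let token = header
    .and_then(bearer_token)
    .ok_or(AuthError::InvalidToken)?;
  verify(token).ok_or(AuthError::InvalidToken)
}
```

A missing header, a header without the Bearer scheme, and a token that fails verification all end up as the same InvalidToken response, so a caller cannot tell which check rejected the request.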

impl IntoResponse for AuthError {
  fn into_response(self) -> Response {
    let (status, error_message) = match self {
      AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
      AuthError::MissingCredential => (StatusCode::BAD_REQUEST, "Missing credentials"),
      AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
      AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
    };
    let body = Json(json!({
      "error": error_message,
    }));
    (status, body).into_response()
  }
}

We set up IntoResponse for AuthError, which creates and sends back the error. Nothing too special. It matches on each possible error and returns a Response made of a status code and a JSON body.

The End 🧙🦀

This is a very high-level overview of the JWT example for Axum. I find it easier to understand by looking at example implementations: the art of other developers that you can learn from and build on. Have fun. See you next time.