From c6f231e02d92e4cb0772ed9d2b106bddf12b52e8 Mon Sep 17 00:00:00 2001 From: brettkolodny Date: Sun, 18 Feb 2024 09:55:14 -0500 Subject: [PATCH] feat: erlang encoding and decoding HS256 --- gleam.toml | 5 +- manifest.toml | 8 +- shell.nix | 16 +++ src/ffi.mjs | 7 -- src/gwt.gleam | 261 +++++++++++++++++++++++++++++--------------- test/gwt_test.gleam | 62 +++++++---- 6 files changed, 240 insertions(+), 119 deletions(-) create mode 100644 shell.nix delete mode 100644 src/ffi.mjs diff --git a/gleam.toml b/gleam.toml index e6b212e..4ea2a96 100644 --- a/gleam.toml +++ b/gleam.toml @@ -10,8 +10,9 @@ version = "1.0.0" # links = [{ title = "Website", href = "https://gleam.run" }] [dependencies] -gleam_stdlib = "~> 0.32" -gleam_json = "~> 0.7" +gleam_stdlib = "~> 0.34" +gleam_json = "~> 1.0" +gleam_crypto = "~> 1.3" [dev-dependencies] gleeunit = "~> 1.0" diff --git a/manifest.toml b/manifest.toml index 71b2630..3f56988 100644 --- a/manifest.toml +++ b/manifest.toml @@ -2,13 +2,15 @@ # You typically do not need to edit this file packages = [ - { name = "gleam_json", version = "0.7.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "CB405BD93A8828BCD870463DE29375E7B2D252D9D124C109E5B618AAC00B86FC" }, + { name = "gleam_crypto", version = "1.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_crypto", source = "hex", outer_checksum = "ADD058DEDE8F0341F1ADE3AAC492A224F15700829D9A3A3F9ADF370F875C51B7" }, + { name = "gleam_json", version = "1.0.0", build_tools = ["gleam"], requirements = ["thoas", "gleam_stdlib"], otp_app = "gleam_json", source = "hex", outer_checksum = "8B197DD5D578EA6AC2C0D4BDC634C71A5BCA8E7DB5F47091C263ECB411A60DF3" }, { name = "gleam_stdlib", version = "0.34.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "1FB8454D2991E9B4C0C804544D8A9AD0F6184725E20D63C3155F0AEB4230B016" }, { name = "gleeunit", version = "1.0.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "D364C87AFEB26BDB4FB8A5ABDE67D635DC9FA52D6AB68416044C35B096C6882D" }, { name = "thoas", version = "0.4.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "4918D50026C073C4AB1388437132C77A6F6F7C8AC43C60C13758CC0ADCE2134E" }, ] [requirements] -gleam_json = { version = "~> 0.7" } -gleam_stdlib = { version = "~> 0.32" } +gleam_crypto = { version = "~> 1.3"} +gleam_json = { version = "~> 1.0" } +gleam_stdlib = { version = "~> 0.34" } gleeunit = { version = "~> 1.0" } diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..8c0f708 --- /dev/null +++ b/shell.nix @@ -0,0 +1,16 @@ +with import {}; + +mkShell { + buildInputs = [ + nodejs_18 + helix + gleam + erlang + rebar3 + cheat + bat + nil + nodePackages.typescript-language-server + ]; + } + diff --git a/src/ffi.mjs b/src/ffi.mjs deleted file mode 100644 index 6b113cc..0000000 --- a/src/ffi.mjs +++ /dev/null @@ -1,7 +0,0 @@ -export const base64Encode = (string) => { - return btoa(string); -} - -export const base64Decode = (string) => { - return atob(string); -} diff --git a/src/gwt.gleam b/src/gwt.gleam index 1f7abf0..27551a0 100644 --- a/src/gwt.gleam +++ b/src/gwt.gleam @@ -1,13 +1,13 @@ // IMPORTS --------------------------------------------------------------------- +import gleam/bit_array import gleam/dict.{type Dict} import gleam/dynamic.{type Dynamic} import gleam/json.{type Json} import gleam/string -import gleam/option.{type Option} as o import gleam/list -import gleam/int import gleam/result +import gleam/crypto // TYPES ----------------------------------------------------------------------- @@ -18,11 +18,11 @@ type Payload = Dict(String, Dynamic) pub opaque type Jwt { - Jwt(header: Header, payload: Payload, signature: Option(String)) + Jwt(header: Header, payload: Payload) } -pub type DecodeErrors { - /// +pub type JwtDecodeError { + /// MissingHeader /// MissingPayload @@ -34,6 +34,12 @@ pub type DecodeErrors { InvalidPayload /// InvalidSignature + /// + NoAlg +} + +pub type Algorithm { + HS256 } // CONSTRUCTORS ---------------------------------------------------------------- @@ -46,11 +52,11 @@ pub fn new() -> Jwt { |> dict.insert("alg", dynamic.from("none")) let payload = dict.new() - Jwt(header, payload, o.None) + Jwt(header, payload) } /// -pub fn from_string(jwt_string: String) -> Result(Jwt, DecodeErrors) { +pub fn from_string(jwt_string: String) -> Result(Jwt, JwtDecodeError) { let jwt_parts = string.split(jwt_string, ".") let maybe_header = list.at(jwt_parts, 0) let maybe_payload = list.at(jwt_parts, 1) @@ -60,27 +66,124 @@ pub fn from_string(jwt_string: String) -> Result(Jwt, DecodeErrors) { Error(Nil), _ -> Error(MissingHeader) _, Error(Nil) -> Error(MissingPayload) Ok(encoded_header), Ok(encoded_payload) -> { - let header = - encoded_header - |> base64_url_safe_to_base64() - |> base64_decode_() - |> json.decode(dynamic.dict(dynamic.string, dynamic.dynamic)) - - let payload = - encoded_payload - |> base64_url_safe_to_base64() - |> base64_decode_() - |> json.decode(dynamic.dict(dynamic.string, dynamic.dynamic)) + let header = { + use res <- result.try( + encoded_header + |> bit_array.base64_url_decode() + |> result.replace_error(InvalidHeader), + ) + + use res <- result.try( + res + |> bit_array.to_string() + |> result.replace_error(InvalidHeader), + ) + + json.decode(res, dynamic.dict(dynamic.string, dynamic.dynamic)) + |> result.replace_error(InvalidHeader) + } + + let payload = { + use res <- result.try( + encoded_payload + |> bit_array.base64_url_decode() + |> result.replace_error(InvalidHeader), + ) + + use res <- result.try( + res + |> bit_array.to_string() + |> result.replace_error(InvalidHeader), + ) + + json.decode(res, dynamic.dict(dynamic.string, dynamic.dynamic)) + |> result.replace_error(InvalidHeader) + } case header, payload { Error(_), _ -> Error(InvalidHeader) _, Error(_) -> Error(InvalidPayload) - Ok(header), Ok(payload) -> Ok(Jwt(header, payload, o.None)) + Ok(header), Ok(payload) -> Ok(Jwt(header, payload)) } } } } +/// +pub fn from_signed_string( + jwt_string: String, + secret: String, +) -> Result(Jwt, JwtDecodeError) { + let jwt_parts = string.split(jwt_string, ".") + + use signature <- result.try( + list.at(jwt_parts, 2) + |> result.replace_error(MissingSignature), + ) + + use encoded_payload <- result.try( + list.at(jwt_parts, 1) + |> result.replace_error(MissingPayload), + ) + + use encoded_header <- result.try( + list.at(jwt_parts, 0) + |> result.replace_error(MissingHeader), + ) + use header_data <- result.try( + encoded_header + |> bit_array.base64_url_decode() + |> result.replace_error(InvalidHeader), + ) + use header_string <- result.try( + header_data + |> bit_array.to_string() + |> result.replace_error(InvalidHeader), + ) + + use header <- result.try( + json.decode(header_string, dynamic.dict(dynamic.string, dynamic.dynamic)) + |> result.replace_error(InvalidHeader), + ) + + use payload_data <- result.try( + encoded_payload + |> bit_array.base64_url_decode() + |> result.replace_error(InvalidHeader), + ) + use payload_string <- result.try( + payload_data + |> bit_array.to_string() + |> result.replace_error(InvalidHeader), + ) + + use payload <- result.try( + json.decode(payload_string, dynamic.dict(dynamic.string, dynamic.dynamic)) + |> result.replace_error(InvalidHeader), + ) + + use alg <- result.try( + dict.get(header, "alg") + |> result.replace_error(NoAlg), + ) + + case dynamic.string(alg) { + Ok("HS256") -> { + let sig = + get_signature(encoded_header <> "." <> encoded_payload, HS256, secret) + case sig == signature { + True -> { + Ok(Jwt(header: header, payload: payload)) + } + False -> Error(InvalidSignature) + } + } + _ -> panic as "Unimplemented signature algorithm" + } +} + +// PAYLOAD --------------------------------------------------------------------- + /// pub fn get_issuer(from jwt: Jwt) -> Result(String, Nil) { use issuer <- result.try( @@ -106,7 +209,7 @@ pub fn get_subject(from jwt: Jwt) -> Result(String, Nil) { } /// -pub fn get_claim( +pub fn get_payload_claim( from jwt: Jwt, claim claim: String, decoder decoder: fn(Dynamic) -> Result(a, List(dynamic.DecodeError)), @@ -122,138 +225,113 @@ pub fn get_claim( } /// -pub fn set_issuer(jwt: Jwt, to iss: String) -> Jwt { +pub fn set_payload_issuer(jwt: Jwt, to iss: String) -> Jwt { let new_payload = dict.insert(jwt.payload, "iss", dynamic.from(iss)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_subject(jwt: Jwt, to sub: String) -> Jwt { +pub fn set_payload_subject(jwt: Jwt, to sub: String) -> Jwt { let new_payload = dict.insert(jwt.payload, "sub", dynamic.from(sub)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_audience(jwt: Jwt, to aud: String) -> Jwt { +pub fn set_payload_audience(jwt: Jwt, to aud: String) -> Jwt { let new_payload = dict.insert(jwt.payload, "aud", dynamic.from(aud)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_expiration(jwt: Jwt, to exp: Int) -> Jwt { +pub fn set_payload_expiration(jwt: Jwt, to exp: Int) -> Jwt { let new_payload = dict.insert(jwt.payload, "exp", dynamic.from(exp)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_not_before(jwt: Jwt, to nbf: Int) -> Jwt { +pub fn set_payload_not_before(jwt: Jwt, to nbf: Int) -> Jwt { let new_payload = dict.insert(jwt.payload, "nbf", dynamic.from(nbf)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_issued_at(jwt: Jwt, to iat: Int) -> Jwt { +pub fn set_payload_issued_at(jwt: Jwt, to iat: Int) -> Jwt { let new_payload = dict.insert(jwt.payload, "iat", dynamic.from(iat)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_jwt_id(jwt: Jwt, to jti: String) -> Jwt { +pub fn set_payload_jwt_id(jwt: Jwt, to jti: String) -> Jwt { let new_payload = dict.insert(jwt.payload, "jti", dynamic.from(jti)) Jwt(..jwt, payload: new_payload) } /// -pub fn set_private_payload_claim( - jwt: Jwt, - set claim: String, - to value: Json, -) -> Jwt { +pub fn set_payload_claim(jwt: Jwt, set claim: String, to value: Json) -> Jwt { let new_payload = dict.insert(jwt.payload, claim, dynamic.from(value)) Jwt(..jwt, payload: new_payload) } +// HEADER ---------------------------------------------------------------------- + +pub fn set_header_claim(jwt: Jwt, set claim: String, to value: Json) -> Jwt { + let new_header = dict.insert(jwt.header, claim, dynamic.from(value)) + + Jwt(..jwt, header: new_header) +} + +// ENCODER --------------------------------------------------------------------- + pub fn to_string(jwt: Jwt) -> String { - let Jwt(header, payload, signature) = jwt + let Jwt(header, payload) = jwt let header_string = header |> dict_to_json_object() |> json.to_string() - |> base64_encode_() - |> base64_string_to_url_safe() + |> bit_array.from_string() + |> bit_array.base64_url_encode(False) let payload_string = payload |> dict_to_json_object() |> json.to_string() - |> base64_encode_() - |> base64_string_to_url_safe() + |> bit_array.from_string() + |> bit_array.base64_url_encode(False) - case signature { - o.Some(s) -> { - let base64_signature = - s - |> base64_encode_() - |> base64_string_to_url_safe() + header_string <> "." <> payload_string +} - header_string <> "." <> payload_string <> "." <> base64_signature - } - o.None -> { - header_string <> "." <> payload_string +pub fn to_signed_string(jwt: Jwt, alg: Algorithm, secret: String) -> String { + case alg { + HS256 -> { + let header_with_alg = + dict.insert(jwt.header, "alg", dynamic.from("HS256")) + let jwt_body = + Jwt(..jwt, header: header_with_alg) + |> to_string() + + let jwt_signature = + jwt_body + |> bit_array.from_string() + |> crypto.hmac(crypto.Sha256, bit_array.from_string(secret)) + |> bit_array.base64_url_encode(False) + + jwt_body <> "." <> jwt_signature } } - - header_string <> "." <> payload_string } // UTILITIES ------------------------------------------------------------------- -@external(erlang, "base64", "encode") -@external(javascript, "./ffi.mjs", "base64Encode") -fn base64_encode_(str: String) -> String - -@external(erlang, "base64", "decode") -@external(javascript, "./ffi.mjs", "base64Decode") -fn base64_decode_(str: String) -> String - -fn base64_string_to_url_safe(str: String) -> String { - str - |> string.replace("=", "") - |> string.replace("+", "-") - |> string.replace("/", "_") -} - -fn base64_url_safe_to_base64(str: String) -> String { - let padding = - str - |> string.length() - |> int.modulo(4) - |> result.unwrap(0) - |> fn(x) { - case x { - 0 -> 0 - _ -> 4 - x - } - } - |> string.repeat("=", _) - - let encoded_string = - str - |> string.replace("-", "+") - |> string.replace("_", "/") - - encoded_string <> padding -} - fn dict_to_json_object(d: Dict(String, Dynamic)) -> Json { let key_value_list = { use acc, key, value <- dict.fold(d, []) @@ -286,3 +364,14 @@ fn dict_to_json_object(d: Dict(String, Dynamic)) -> Json { json.object(key_value_list) } + +fn get_signature(data: String, algorithm: Algorithm, secret: String) -> String { + case algorithm { + HS256 -> { + data + |> bit_array.from_string() + |> crypto.hmac(crypto.Sha256, bit_array.from_string(secret)) + |> bit_array.base64_url_encode(False) + } + } +} diff --git a/test/gwt_test.gleam b/test/gwt_test.gleam index e622e92..760e292 100644 --- a/test/gwt_test.gleam +++ b/test/gwt_test.gleam @@ -3,48 +3,68 @@ import gleeunit/should import gleam/dynamic import gwt -// const signing_secret = "gleam" +const signing_secret = "gleam" pub fn main() { gleeunit.main() } -pub fn encode_unsigned_jwt_test() { +pub fn encode_decode_unsigned_jwt_test() { let jwt_string = gwt.new() - |> gwt.set_subject("1234567890") - |> gwt.set_audience("0987654321") - |> gwt.set_not_before(1_704_043_160) - |> gwt.set_expiration(1_704_046_160) - |> gwt.set_jwt_id("2468") + |> gwt.set_payload_subject("1234567890") + |> gwt.set_payload_audience("0987654321") + |> gwt.set_payload_not_before(1_704_043_160) + |> gwt.set_payload_expiration(1_704_046_160) + |> gwt.set_payload_jwt_id("2468") |> gwt.to_string() - jwt_string - |> should.equal( - "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmJmIjoxNzA0MDQzMTYwLCJqdGkiOiIyNDY4IiwiZXhwIjoxNzA0MDQ2MTYwLCJhdWQiOiIwOTg3NjU0MzIxIn0", - ) + let maybe_jwt = gwt.from_string(jwt_string) + + maybe_jwt + |> should.be_ok() + + let assert Ok(jwt) = gwt.from_string(jwt_string) + + gwt.get_subject(jwt) + |> should.equal(Ok("1234567890")) + + jwt + |> gwt.get_payload_claim("aud", dynamic.string) + |> should.equal(Ok("0987654321")) + + jwt + |> gwt.get_payload_claim("iss", dynamic.string) + |> should.equal(Error(Nil)) } -pub fn decode_unsigned_jwt_test() { +pub fn encode_decode_signed_jwt_test() { let jwt_string = gwt.new() - |> gwt.set_subject("1234567890") - |> gwt.set_audience("0987654321") - |> gwt.set_not_before(1_704_043_160) - |> gwt.set_expiration(1_704_046_160) - |> gwt.set_jwt_id("2468") - |> gwt.to_string() + |> gwt.set_payload_subject("1234567890") + |> gwt.set_payload_audience("0987654321") + |> gwt.to_signed_string(gwt.HS256, signing_secret) - let assert Ok(jwt) = gwt.from_string(jwt_string) + gwt.from_signed_string(jwt_string, "bad secret") + |> should.be_error + + gwt.from_signed_string(jwt_string, "bad secret") + |> should.equal(Error(gwt.InvalidSignature)) + + let maybe_jwt = gwt.from_signed_string(jwt_string, signing_secret) + maybe_jwt + |> should.be_ok() + + let assert Ok(jwt) = gwt.from_signed_string(jwt_string, signing_secret) gwt.get_subject(jwt) |> should.equal(Ok("1234567890")) jwt - |> gwt.get_claim("aud", dynamic.string) + |> gwt.get_payload_claim("aud", dynamic.string) |> should.equal(Ok("0987654321")) jwt - |> gwt.get_claim("iss", dynamic.string) + |> gwt.get_payload_claim("iss", dynamic.string) |> should.equal(Error(Nil)) }