Commuting Monad Transformers

2024-06-27

One of the issues in choosing a monad transformer stack is that of commutativity. Given any two monad transformers, composing them may give the same monad regardless of order or it may not.

It's useful to know which ones do, because for those pairs you don't have to worry about how the stacking order changes the semantics of your monad's operations. I thought Lean would be the perfect language to showcase this: it's used as both a functional programming language and a proof assistant, so it's well-suited to showing the (non)commutativity of monad transformers.

Preliminaries

Let's first show the definitions that Lean uses for some of the most commonly used transformers:

def StateT (σ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
  σ → m (α × σ)

def ReaderT (ρ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
  ρ → m α

def ExceptT (ε : Type u) (m : Type u → Type v) (α : Type u) : Type v :=
  m (Except ε α)

The objective is to show for which pairs of transformers the order of composition matters, though we don't necessarily want an exhaustive list -- a methodology is more interesting.

ReaderT and ExceptT

To start, let's pick two of these -- say, ReaderT and ExceptT. An informative proof, though unnecessarily long, is as follows:

theorem equiv_ReaderT_ExceptT : ReaderT ρ (ExceptT ε m) α = ExceptT ε (ReaderT ρ m) α := by
  unfold ReaderT
  unfold ExceptT
  simp
  /-
  ρ ε : Type u_1
  m : Type u_1 → Type u_2
  α : Type u_1
  ⊢ (ρ → ExceptT ε m α) = ExceptT ε (fun α => ρ → m α) α
  -/

By moving the cursor along the lines of the proof, you can see the definitions of each term in the infoview as they unfold, and verify each step. The comment shows the goal with the cursor placed between unfold ReaderT and unfold ExceptT.

However, a shorter proof is the following:

theorem equiv_ReaderT_ExceptT : ReaderT ρ (ExceptT ε m) α = ExceptT ε (ReaderT ρ m) α := by
  rfl

In this case, Lean is able to see that the two terms are definitionally equal, by unfolding definitions behind the scenes.
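To make "definitionally equal" a bit more concrete, here is a small sketch that compares each stack against the fully unfolded function type. The explicit universe and binder declarations are my own choice here, not something the theorem above needs.

```lean
universe u v

-- ReaderT over ExceptT unfolds to a plain function type ...
example {ρ ε α : Type u} {m : Type u → Type v} :
    ReaderT ρ (ExceptT ε m) α = (ρ → m (Except ε α)) := rfl

-- ... and ExceptT over ReaderT unfolds to the very same type.
example {ρ ε α : Type u} {m : Type u → Type v} :
    ExceptT ε (ReaderT ρ m) α = (ρ → m (Except ε α)) := rfl
```

Since both sides reduce to the same term, rfl has nothing left to prove.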

But what happens if we replace ReaderT with StateT? From their definitions, you can see that the only difference is that StateT also returns a (possibly updated) state alongside the result, inside the base monad. Does this change anything?

StateT and ExceptT

The goal is to figure out whether the following two are the same:

theorem equiv_ExceptT_StateT : StateT σ (ExceptT ε m) α = ExceptT ε (StateT σ m) α := by
  rfl
  /-
  The rfl tactic failed. Possible reasons:
  - The goal is not a reflexive relation (neither `=` nor a relation with a @[refl] lemma).
  - The arguments of the relation are not equal.
  Try using the reflexivitiy lemma for your relation explicitly, e.g. `exact Eq.rfl`.
  
  σ ε : Type u_1
  m : Type u_1 → Type u_2
  α : Type u_1
  ⊢ StateT σ (ExceptT ε m) α = ExceptT ε (StateT σ m) α
  -/

Hrm, something broke in going from ReaderT to StateT, even though the two look so alike. We can use an approach similar to the first proof to see exactly what goes wrong, by unfolding definitions and then comparing the two terms at their simplest:

theorem equiv_ExceptT_StateT : StateT σ (ExceptT ε m) α = ExceptT ε (StateT σ m) α := by
  unfold StateT
  unfold ExceptT
  simp
  /-
  σ ε : Type u_1
  m : Type u_1 → Type u_2
  α : Type u_1
  ⊢ (σ → m (Except ε (α × σ))) = (σ → m (Except ε α × σ))
  -/

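And therein lies the problem: on the left, m wraps Except ε (α × σ), while on the right it wraps (Except ε α) × σ. On the left the state sits inside the Except, so it only exists when there is no error; on the right it sits outside and is always present.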

ReaderT and StateT

These commute, but only up to isomorphism, not by definition. In other words, the proof is not merely an appeal to definitions; a bit more work has to be done. But only a bit.

If we go the same route as above, we get the following:

theorem equiv_StateT_ReaderT : ReaderT ρ (StateT σ m) α = StateT σ (ReaderT ρ m) α := by
  unfold ReaderT
  unfold StateT
  simp
  /-
  ρ σ : Type u_1
  m : Type u_1 → Type u_2
  α : Type u_1
  ⊢ (ρ → σ → m (α × σ)) = (σ → ρ → m (α × σ))
  -/

We see the left-hand side has the ReaderT environment before the StateT state, and the right-hand side has it flipped. This is why we can't use the rfl tactic -- but it's clear that if we swap the positions of the arguments, they'd be the same.

There are a few ways to show this, but I've chosen the following:

def TypeEquiv (α β : Type) : Prop :=
  ∃ f : α -> β, ∃ g : β -> α, f ∘ g = id ∧ g ∘ f = id
infixl:25 " ≃ " => TypeEquiv

In words, two types are equivalent if we can exhibit an isomorphism between them. Note that it's not enough to just supply arbitrary functions with the right types; otherwise, we could show that Nat and Unit are equivalent, for example, by fun x : Nat => () and fun _ => 0.
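As a quick sanity check of why the round-trip conditions matter, here is a minimal sketch using the two functions just mentioned. The names toUnit and fromUnit are hypothetical, introduced only for this illustration.

```lean
-- Hypothetical names for the two functions mentioned above.
def toUnit : Nat → Unit := fun _ => ()
def fromUnit : Unit → Nat := fun _ => 0

-- Round-tripping 1 through Unit collapses it to 0 ...
example : (fromUnit ∘ toUnit) 1 = 0 := rfl

-- ... so `fromUnit ∘ toUnit` disagrees with `id` at 1 and cannot equal it.
example : (fromUnit ∘ toUnit) 1 ≠ id 1 := by decide
```

The functions have the right types, but the second conjunct of TypeEquiv (g ∘ f = id) fails for them.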

Clearly, we can show any type is equivalent to itself:

theorem equiv_refl { α : Type } : α ≃ α := by
  exists id
  exists id

(In fact, we can even show it's an equivalence relation¹). So in this sense, equivalence is an extension of equality. With this new concept, we can show commutativity up to isomorphism:

theorem equiv_StateT_ReaderT : ReaderT ρ (StateT σ m) α ≃ StateT σ (ReaderT ρ m) α := by
  exists flip
  exists flip
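The reason flip works as both witnesses can be sketched as follows. The name swapLayers is hypothetical and only spells out what flip does on these particular types; I've fixed everything to Type to match TypeEquiv.

```lean
-- Hypothetical helper: the conversion written out by hand.
-- It takes the state first and the environment second, then
-- feeds them to the original computation in the opposite order.
def swapLayers {ρ σ α : Type} {m : Type → Type}
    (x : ReaderT ρ (StateT σ m) α) : StateT σ (ReaderT ρ m) α :=
  fun s r => x r s

-- Swapping the first two arguments twice gives back the original
-- function, and Lean accepts this by `rfl` (using eta).
example {α β γ : Type} (f : α → β → γ) : flip (flip f) = f := rfl
```

So flip is its own inverse, which is what both halves of the conjunction in TypeEquiv ask for.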

Why doesn't this work for the previous example, StateT and ExceptT? Recall that one version unfolded to

StateT σ (ExceptT ε m) α = σ → m (Except ε (α × σ))

The other, to

ExceptT ε (StateT σ m) α = σ → m (Except ε α × σ)

Informally, the second can recover the modified state whether or not an exception was thrown, while the first can do so only if there was no exception. The two types carry different information, so in general there is no isomorphism between them.
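A small experiment can make this difference visible. The sketch below assumes Id as the base monad and uses two hypothetical computations, bumpThenFail₁ and bumpThenFail₂, which are the same program (increment the state, then throw) written in each stack.

```lean
-- Hypothetical example: StateT on the outside, ExceptT underneath.
def bumpThenFail₁ : StateT Nat (ExceptT String Id) Unit := do
  modify (· + 1)
  throw "boom"

-- Hypothetical example: ExceptT on the outside, StateT underneath.
def bumpThenFail₂ : ExceptT String (StateT Nat Id) Unit := do
  modify (· + 1)
  throw "boom"

-- Result type: Except String (Unit × Nat).
-- The error branch has no slot for the state, so the update is lost.
#eval (bumpThenFail₁.run 0).run

-- Result type: Except String Unit × Nat.
-- The state is returned next to the error, so the update survives.
#eval bumpThenFail₂.run.run 0
```

The first evaluation should yield only the error, while the second should yield the error paired with the updated state, which matches the two unfolded types above.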


(1) Here are the proofs that our TypeEquiv relation is an equivalence relation. We already showed reflexivity, so symmetry (equiv_comm) and transitivity (equiv_trans) remain.

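theorem equiv_comm { α β : Type } : α ≃ β -> β ≃ α := by
  intro asm
  let ⟨f, ⟨g, h_iso⟩⟩ := asm
  -- Swap the two round-trip facts so they match the goal with f and g exchanged.
  have : g ∘ f = id ∧ f ∘ g = id := ⟨h_iso.right, h_iso.left⟩
  -- Supply the witnesses in the opposite order.
  exists g
  exists f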

theorem equiv_trans { α β γ : Type } : α ≃ β -> β ≃ γ -> α ≃ γ := by
  intro eq_αβ eq_βγ
  let ⟨αβ, ⟨βα, αβ_iso⟩⟩ := eq_αβ
  let ⟨βγ, ⟨γβ, βγ_iso⟩⟩ := eq_βγ
  exists βγ ∘ αβ
  exists βα ∘ γβ
  constructor
  calc
    (βγ ∘ αβ) ∘ βα ∘ γβ
    _ = βγ ∘ (αβ ∘ βα) ∘ γβ := rfl
    _ = βγ ∘ id ∘ γβ := by simp [αβ_iso]
    _ = βγ ∘ γβ := rfl
    _ = id := by simp [βγ_iso]
  calc
    (βα ∘ γβ) ∘ βγ ∘ αβ
    _ = βα ∘ (γβ ∘ βγ) ∘ αβ := rfl
    _ = βα ∘ id ∘ αβ := by simp [βγ_iso]
    _ = βα ∘ αβ := rfl
    _ = id := by simp [αβ_iso]