aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorAmir Saeid <amir@glgdgt.com>2026-02-14 12:53:37 +0000
committerAmir Saeid <amir@glgdgt.com>2026-02-14 12:53:37 +0000
commit33c328fe9e08e642b28b310f9ab7f2fa704a3a2f (patch)
tree9e69e52e31dd592ef5b5cd3050329dc0d9f67dc7 /core/src
parent21cd15d7a259e3ada2401e2585455587a56bdfbd (diff)
Add tests
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/com/codiff/fairstream/Fair.scala52
-rw-r--r--core/src/main/scala/com/codiff/fairstream/FairT.scala110
-rw-r--r--core/src/test/scala/com/codiff/fairstream/PythagoreanSuite.scala99
3 files changed, 229 insertions, 32 deletions
diff --git a/core/src/main/scala/com/codiff/fairstream/Fair.scala b/core/src/main/scala/com/codiff/fairstream/Fair.scala
index 1436eba..84255db 100644
--- a/core/src/main/scala/com/codiff/fairstream/Fair.scala
+++ b/core/src/main/scala/com/codiff/fairstream/Fair.scala
@@ -7,7 +7,15 @@ sealed trait Fair[+A]
object Fair {
case object Nil extends Fair[Nothing]
case class One[+A](a: A) extends Fair[A]
- case class Choice[+A](a: A, rest: Fair[A]) extends Fair[A]
+ class Choice[+A](val a: A, expr: => Fair[A]) extends Fair[A] {
+ lazy val rest: Fair[A] = expr
+ }
+
+ object Choice {
+ def apply[A](a: A, expr: => Fair[A]): Choice[A] = new Choice(a, expr)
+
+ def unapply[A](s: Choice[A]): Some[(A, Fair[A])] = Some((s.a, s.rest))
+ }
class Incomplete[+A](expr: => Fair[A]) extends Fair[A] {
lazy val step: Fair[A] = expr
@@ -28,15 +36,35 @@ object Fair {
def guard(cond: Boolean): Fair[Unit] = if (cond) unit(()) else empty
def mplus[A](left: Fair[A], right: => Fair[A]): Fair[A] = left match {
- case Nil => Incomplete(right)
- case One(a) => Choice(a, right)
- case Choice(a, r) => Choice(a, mplus(right, r))
- case Incomplete(i) =>
+ case Nil => Incomplete(right)
+ case One(a) => Choice(a, right)
+ case c: Choice[A @unchecked] => Choice(c.a, mplus(right, c.rest))
+ case inc: Incomplete[A @unchecked] =>
right match {
- case Nil => Incomplete(i)
- case One(b) => Choice(b, i)
- case Choice(b, r2) => Choice(b, Incomplete(mplus(i, r2)))
- case Incomplete(j) => Incomplete(mplus(i, j))
+ case Nil => inc
+ case One(b) => Choice(b, inc.step)
+ case Choice(b, r2) => Choice(b, Incomplete(mplus(inc.step, r2)))
+ case Incomplete(j) => Incomplete(mplus(inc.step, j))
+ }
+ }
+
+ @annotation.tailrec
+ def runM[A](
+ maxDepth: Option[Int],
+ maxResults: Option[Int],
+ stream: Fair[A],
+ acc: List[A] = List.empty
+ ): List[A] = {
+ if (maxResults.exists(_ <= 0)) acc.reverse
+ else
+ stream match {
+ case Nil => acc.reverse
+ case One(a) => (a :: acc).reverse
+ case Choice(a, r) =>
+ runM(maxDepth, maxResults.map(_ - 1), r, a :: acc)
+ case Incomplete(i) =>
+ if (maxDepth.exists(_ <= 0)) acc.reverse
+ else runM(maxDepth.map(_ - 1), maxResults, i, acc)
}
}
@@ -50,8 +78,10 @@ object Fair {
def flatMap[A, B](fa: Fair[A])(f: A => Fair[B]): Fair[B] = fa match {
case Nil => Nil
case One(a) => f(a)
- case Choice(a, r) => combineK(f(a), Incomplete(flatMap(r)(f)))
- case Incomplete(i) => Incomplete(flatMap(i)(f))
+ case c: Choice[A @unchecked] =>
+ combineK(f(c.a), Incomplete(flatMap(c.rest)(f)))
+ case i: Incomplete[A @unchecked] =>
+ Incomplete(flatMap(i.step)(f))
}
def combineK[A](x: Fair[A], y: Fair[A]): Fair[A] = mplus(x, y)
diff --git a/core/src/main/scala/com/codiff/fairstream/FairT.scala b/core/src/main/scala/com/codiff/fairstream/FairT.scala
index c652c0f..d84234a 100644
--- a/core/src/main/scala/com/codiff/fairstream/FairT.scala
+++ b/core/src/main/scala/com/codiff/fairstream/FairT.scala
@@ -1,14 +1,30 @@
package com.codiff.fairstream
-import cats.{Alternative, Applicative, Monad}
+import cats.{Alternative, Applicative, Monad, ~>}
sealed trait FairE[M[_], A]
object FairE {
final case class Nil[M[_], A]() extends FairE[M, A]
final case class One[M[_], A](a: A) extends FairE[M, A]
- final case class Choice[M[_], A](a: A, rest: FairT[M, A]) extends FairE[M, A]
- final case class Incomplete[M[_], A](rest: FairT[M, A]) extends FairE[M, A]
+ class Choice[M[_], A](val a: A, expr: => FairT[M, A]) extends FairE[M, A] {
+ lazy val rest: FairT[M, A] = expr
+ }
+
+ object Choice {
+ def apply[M[_], A](a: A, expr: => FairT[M, A]): Choice[M, A] = new Choice(a, expr)
+
+ def unapply[M[_], A](s: Choice[M, A]): Some[(A, FairT[M, A])] = Some((s.a, s.rest))
+ }
+ class Incomplete[M[_], A](expr: => FairT[M, A]) extends FairE[M, A] {
+ lazy val rest: FairT[M, A] = expr
+ }
+
+ object Incomplete {
+ def apply[M[_], A](expr: => FairT[M, A]): Incomplete[M, A] = new Incomplete(expr)
+
+ def unapply[M[_], A](s: Incomplete[M, A]): Some[FairT[M, A]] = Some(s.rest)
+ }
}
final case class FairT[M[_], A](run: M[FairE[M, A]])
@@ -20,25 +36,36 @@ object FairT {
def unit[M[_], A](a: A)(implicit M: Applicative[M]): FairT[M, A] =
FairT(M.pure[FairE[M, A]](FairE.One(a)))
- def suspend[M[_], A](s: FairT[M, A])(implicit
+ def suspend[M[_], A](s: => FairT[M, A])(implicit
M: Applicative[M]
): FairT[M, A] =
FairT(M.pure[FairE[M, A]](FairE.Incomplete(s)))
+ def lift[M[_], A](ma: M[A])(implicit M: Monad[M]): FairT[M, A] =
+ FairT(M.map(ma)(a => FairE.One[M, A](a): FairE[M, A]))
+
+ def liftK[M[_]](implicit M: Monad[M]): M ~> FairT[M, *] =
+ new (M ~> FairT[M, *]) {
+ def apply[A](ma: M[A]): FairT[M, A] = lift(ma)
+ }
+
def mplus[M[_], A](left: FairT[M, A], right: => FairT[M, A])(implicit
M: Monad[M]
): FairT[M, A] = {
type E = FairE[M, A]
FairT(M.flatMap[E, E](left.run) {
- case FairE.Nil() => M.pure[E](FairE.Incomplete(right))
- case FairE.One(a) => M.pure[E](FairE.Choice(a, right))
- case FairE.Choice(a, r) => M.pure[E](FairE.Choice(a, mplus(right, r)))
- case FairE.Incomplete(i) =>
- M.map[E, E](right.run) {
- case FairE.Nil() => FairE.Incomplete(i)
- case FairE.One(b) => FairE.Choice(b, i)
- case FairE.Choice(b, r2) => FairE.Choice(b, mplus(i, r2))
- case FairE.Incomplete(j) => FairE.Incomplete(mplus(i, j))
+ case FairE.Nil() => M.pure[E](FairE.Incomplete(right))
+ case FairE.One(a) => M.pure[E](FairE.Choice(a, right))
+ case c: FairE.Choice[M, A] @unchecked =>
+ M.pure[E](FairE.Choice(c.a, mplus(right, c.rest)))
+ case inc: FairE.Incomplete[M, A] @unchecked =>
+ M.flatMap[E, E](right.run) {
+ case FairE.Nil() => M.pure[E](inc)
+ case FairE.One(b) => M.pure[E](FairE.Choice(b, inc.rest))
+ case rc: FairE.Choice[M, A] @unchecked =>
+ M.pure[E](FairE.Choice(rc.a, FairT(M.pure[E](FairE.Incomplete(mplus(inc.rest, rc.rest))))))
+ case rinc: FairE.Incomplete[M, A] @unchecked =>
+ M.pure[E](FairE.Incomplete(mplus(inc.rest, rinc.rest)))
}
})
}
@@ -48,13 +75,33 @@ object FairT {
)(f: A => FairT[M, B])(implicit M: Monad[M]): FairT[M, B] = {
type EB = FairE[M, B]
FairT(M.flatMap[FairE[M, A], EB](fa.run) {
- case FairE.Nil() => M.pure[EB](FairE.Nil())
- case FairE.One(a) => f(a).run
- case FairE.Choice(a, r) => mplus(f(a), suspend(flatMap(r)(f))).run
- case FairE.Incomplete(i) => M.pure[EB](FairE.Incomplete(flatMap(i)(f)))
+ case FairE.Nil() => M.pure[EB](FairE.Nil())
+ case FairE.One(a) => f(a).run
+ case c: FairE.Choice[M, A] @unchecked =>
+ mplus(f(c.a), suspend(flatMap(c.rest)(f))).run
+ case i: FairE.Incomplete[M, A] @unchecked =>
+ M.pure[EB](FairE.Incomplete(flatMap(i.rest)(f)))
})
}
+ def runM[M[_], A](
+ maxDepth: Option[Int],
+ maxResults: Option[Int],
+ stream: FairT[M, A]
+ )(implicit M: Monad[M]): M[List[A]] = {
+ if (maxResults.exists(_ <= 0)) M.pure(List.empty)
+ else
+ M.flatMap(stream.run) {
+ case FairE.Nil() => M.pure(List.empty)
+ case FairE.One(a) => M.pure(List(a))
+ case c: FairE.Choice[M, A] @unchecked =>
+ M.map(runM(maxDepth, maxResults.map(_ - 1), c.rest))(c.a :: _)
+ case inc: FairE.Incomplete[M, A] @unchecked =>
+ if (maxDepth.exists(_ <= 0)) M.pure(List.empty)
+ else runM(maxDepth.map(_ - 1), maxResults, inc.rest)
+ }
+ }
+
implicit def fairTMonad[M[_]: Monad]
: Monad[FairT[M, *]] with Alternative[FairT[M, *]] =
new Monad[FairT[M, *]] with Alternative[FairT[M, *]] {
@@ -65,11 +112,32 @@ object FairT {
def flatMap[A, B](fa: FairT[M, A])(f: A => FairT[M, B]): FairT[M, B] =
FairT.flatMap(fa)(f)
- def tailRecM[A, B](a: A)(f: A => FairT[M, Either[A, B]]): FairT[M, B] =
- flatMap(f(a)) {
- case Left(next) => tailRecM(next)(f)
- case Right(b) => FairT.unit(b)
+ def tailRecM[A, B](a: A)(f: A => FairT[M, Either[A, B]]): FairT[M, B] = {
+ val MM = Monad[M]
+ type E = FairE[M, B]
+ val cont: Either[A, B] => FairT[M, B] = {
+ case Left(a) => tailRecM(a)(f)
+ case Right(b) => FairT.unit(b)
}
+ FairT[M, B](MM.tailRecM[A, E](a) { a =>
+ MM.map[FairE[M, Either[A, B]], Either[A, E]](f(a).run) {
+ case FairE.Nil() => Right(FairE.Nil())
+ case FairE.One(Left(a)) => Left(a)
+ case FairE.One(Right(b)) => Right(FairE.One(b))
+ case c: FairE.Choice[M, Either[A, B]] @unchecked =>
+ val rest: FairT[M, B] = FairT.flatMap[M, Either[A, B], B](c.rest)(cont)
+ c.a match {
+ case Right(b) => Right(FairE.Choice(b, rest))
+ case Left(a) =>
+ Right(FairE.Incomplete(mplus[M, B](tailRecM(a)(f), rest)))
+ }
+ case inc: FairE.Incomplete[M, Either[A, B]] @unchecked =>
+ Right(
+ FairE.Incomplete(FairT.flatMap[M, Either[A, B], B](inc.rest)(cont))
+ )
+ }
+ })
+ }
def combineK[A](x: FairT[M, A], y: FairT[M, A]): FairT[M, A] = mplus(x, y)
}
diff --git a/core/src/test/scala/com/codiff/fairstream/PythagoreanSuite.scala b/core/src/test/scala/com/codiff/fairstream/PythagoreanSuite.scala
new file mode 100644
index 0000000..8ab73ba
--- /dev/null
+++ b/core/src/test/scala/com/codiff/fairstream/PythagoreanSuite.scala
@@ -0,0 +1,99 @@
+package com.codiff.fairstream
+
+import cats.Eval
+import cats.syntax.all._
+import munit.FunSuite
+
+class PythagoreanSuite extends FunSuite {
+
+ def isPythagorean(t: (Int, Int, Int)): Boolean = {
+ val (i, j, k) = t
+ i * i + j * j == k * k
+ }
+
+ // -- Fair tests (FBackTrack.hs) --
+
+ test("Fair: pythagorean triples (generic MonadPlus style)") {
+ import Fair._
+
+ lazy val number: Fair[Int] = mplus(unit(0), number.map(_ + 1))
+
+ val triples = for {
+ i <- number
+ _ <- guard(i > 0)
+ j <- number
+ _ <- guard(j > 0)
+ k <- number
+ _ <- guard(k > 0)
+ _ <- guard(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val results = Fair.runM(None, Some(7), triples)
+ assertEquals(results.length, 7)
+ assert(results.forall(isPythagorean))
+ }
+
+ test("Fair: pythagorean triples with left recursion") {
+ import Fair._
+
+ lazy val number: Fair[Int] = mplus((Incomplete(number): Fair[Int]).map(_ + 1), unit(0))
+
+ val triples = for {
+ i <- number
+ j <- number
+ k <- number
+ _ <- guard(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val results = Fair.runM(None, Some(27), triples)
+ assertEquals(results.length, 27)
+ assert(results.forall(isPythagorean))
+ }
+
+ // -- FairT tests (FBackTrackT.hs) --
+ // Using Eval as M because:
+ // - Id is strict with no trampolining → stack overflow on deep Incomplete chains
+ // - Eval provides stack-safe trampolined flatMap with minimal overhead
+ // - IO works but has too much per-step overhead for large searches
+
+ def guardF(cond: Boolean): FairT[Eval, Unit] =
+ if (cond) FairT.unit[Eval, Unit](()) else FairT.empty[Eval, Unit]
+
+ test("FairT[Eval]: pythagorean triples (generic MonadPlus style)") {
+ lazy val number: FairT[Eval, Int] =
+ FairT.mplus[Eval, Int](FairT.unit[Eval, Int](0), number.map(_ + 1))
+
+ val triples: FairT[Eval, (Int, Int, Int)] = for {
+ i <- number
+ _ <- guardF(i > 0)
+ j <- number
+ _ <- guardF(j > 0)
+ k <- number
+ _ <- guardF(k > 0)
+ _ <- guardF(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val results = FairT.runM[Eval, (Int, Int, Int)](None, Some(7), triples).value
+ assertEquals(results.length, 7)
+ assert(results.forall(isPythagorean))
+ }
+
+ test("FairT[Eval]: pythagorean triples with left recursion") {
+ lazy val number: FairT[Eval, Int] =
+ FairT.mplus[Eval, Int](
+ FairT.suspend[Eval, Int](number).map(_ + 1),
+ FairT.unit[Eval, Int](0)
+ )
+
+ val triples: FairT[Eval, (Int, Int, Int)] = for {
+ i <- number
+ j <- number
+ k <- number
+ _ <- guardF(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val results = FairT.runM[Eval, (Int, Int, Int)](None, Some(27), triples).value
+ assertEquals(results.length, 27)
+ assert(results.forall(isPythagorean))
+ }
+}