aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--build.sbt15
-rw-r--r--fs2/src/main/scala/com/codiff/fairstream/fs2/conversions.scala34
-rw-r--r--fs2/src/main/scala/com/codiff/fairstream/fs2/syntax.scala18
-rw-r--r--fs2/src/test/scala/com/codiff/fairstream/fs2/FairFs2Suite.scala155
4 files changed, 221 insertions, 1 deletions
diff --git a/build.sbt b/build.sbt
index e92779f..5f829ee 100644
--- a/build.sbt
+++ b/build.sbt
@@ -17,7 +17,7 @@ val Scala213 = "2.13.18"
ThisBuild / crossScalaVersions := Seq(Scala213, "3.3.7")
ThisBuild / scalaVersion := Scala213 // the default Scala
-lazy val root = tlCrossRootProject.aggregate(core)
+lazy val root = tlCrossRootProject.aggregate(core, fs2)
lazy val core = crossProject(JVMPlatform, JSPlatform)
.crossType(CrossType.Pure)
@@ -32,4 +32,17 @@ lazy val core = crossProject(JVMPlatform, JSPlatform)
)
)
+lazy val fs2 = crossProject(JVMPlatform, JSPlatform)
+ .crossType(CrossType.Pure)
+ .in(file("fs2"))
+ .dependsOn(core)
+ .settings(
+ name := "fairstream-fs2",
+ libraryDependencies ++= Seq(
+ "co.fs2" %%% "fs2-core" % "3.12.2",
+ "org.scalameta" %%% "munit" % "1.2.2" % Test,
+ "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test
+ )
+ )
+
lazy val docs = project.in(file("site")).enablePlugins(TypelevelSitePlugin)
diff --git a/fs2/src/main/scala/com/codiff/fairstream/fs2/conversions.scala b/fs2/src/main/scala/com/codiff/fairstream/fs2/conversions.scala
new file mode 100644
index 0000000..4083c31
--- /dev/null
+++ b/fs2/src/main/scala/com/codiff/fairstream/fs2/conversions.scala
@@ -0,0 +1,34 @@
+package com.codiff.fairstream
+package fs2
+
+import cats.Monad
+import _root_.fs2.{Pull, Pure, Stream}
+
+object conversions {
+
+ def fairToStream[A](fair: Fair[A]): Stream[Pure, A] = {
+ def go(f: Fair[A]): Pull[Pure, A, Unit] = f match {
+ case Fair.Nil => Pull.done
+ case Fair.One(a) => Pull.output1(a)
+ case c: Fair.Choice[A @unchecked] => Pull.output1(c.a) >> go(c.rest)
+ case i: Fair.Incomplete[A @unchecked] => go(i.step)
+ }
+ go(fair).stream
+ }
+
+ def fairTToStream[F[_], A](fairT: FairT[F, A])(implicit
+ F: Monad[F]
+ ): Stream[F, A] = {
+ def go(ft: FairT[F, A]): Pull[F, A, Unit] =
+ Pull.eval(ft.run).flatMap {
+ case FairE.Nil() => Pull.done
+ case FairE.One(a) => Pull.output1(a)
+ case c: FairE.Choice[F, A] @unchecked =>
+ Pull.output1(c.a) >> go(c.rest)
+ case inc: FairE.Incomplete[F, A] @unchecked =>
+ go(inc.rest)
+ }
+ go(fairT).stream
+ }
+
+}
diff --git a/fs2/src/main/scala/com/codiff/fairstream/fs2/syntax.scala b/fs2/src/main/scala/com/codiff/fairstream/fs2/syntax.scala
new file mode 100644
index 0000000..1c7aba7
--- /dev/null
+++ b/fs2/src/main/scala/com/codiff/fairstream/fs2/syntax.scala
@@ -0,0 +1,18 @@
+package com.codiff.fairstream
+package fs2
+
+import cats.Monad
+import _root_.fs2.{Pure, Stream}
+
+object syntax {
+
+ implicit class FairFs2Ops[A](val fair: Fair[A]) extends AnyVal {
+ def toFs2: Stream[Pure, A] = conversions.fairToStream(fair)
+ }
+
+ implicit class FairTFs2Ops[F[_], A](val fairT: FairT[F, A]) extends AnyVal {
+ def toFs2(implicit F: Monad[F]): Stream[F, A] =
+ conversions.fairTToStream(fairT)
+ }
+
+}
diff --git a/fs2/src/test/scala/com/codiff/fairstream/fs2/FairFs2Suite.scala b/fs2/src/test/scala/com/codiff/fairstream/fs2/FairFs2Suite.scala
new file mode 100644
index 0000000..cbb2975
--- /dev/null
+++ b/fs2/src/test/scala/com/codiff/fairstream/fs2/FairFs2Suite.scala
@@ -0,0 +1,155 @@
+package com.codiff.fairstream
+package fs2
+
+import scala.concurrent.duration._
+
+import cats.effect.IO
+import cats.syntax.all._
+import munit.CatsEffectSuite
+
+import syntax._
+
+class FairFs2Suite extends CatsEffectSuite {
+
+ test("Fair.toFs2: empty stream") {
+ val result = Fair.empty[Int].toFs2.toList
+ assertEquals(result, List.empty[Int])
+ }
+
+ test("Fair.toFs2: single element") {
+ val result = Fair.unit(42).toFs2.toList
+ assertEquals(result, List(42))
+ }
+
+ test("Fair.toFs2: finite stream") {
+ import Fair._
+ val stream = mplus(unit(1), mplus(unit(2), unit(3)))
+ val result = stream.toFs2.toList
+ val expected = Fair.runM(None, None, stream)
+ assertEquals(result, expected)
+ }
+
+ test("Fair.toFs2: infinite stream (take N)") {
+ import Fair._
+ lazy val number: Fair[Int] = mplus(unit(0), number.map(_ + 1))
+ val result = number.toFs2.take(20).toList
+ val expected = Fair.runM(None, Some(20), number)
+ assertEquals(result, expected)
+ }
+
+ test("Fair.toFs2: pythagorean triples match runM output") {
+ import Fair._
+ lazy val number: Fair[Int] = mplus(unit(0), number.map(_ + 1))
+
+ val triples = for {
+ i <- number
+ _ <- guard(i > 0)
+ j <- number
+ _ <- guard(j > 0)
+ k <- number
+ _ <- guard(k > 0)
+ _ <- guard(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val n = 7
+ val result = triples.toFs2.take(n.toLong).toList
+ val expected = Fair.runM(None, Some(n), triples)
+ assertEquals(result, expected)
+ }
+
+ test("FairT[IO].toFs2: empty stream") {
+ FairT.empty[IO, Int].toFs2.compile.toList.map { result =>
+ assertEquals(result, List.empty[Int])
+ }
+ }
+
+ test("FairT[IO].toFs2: single element") {
+ FairT.unit[IO, Int](42).toFs2.compile.toList.map { result =>
+ assertEquals(result, List(42))
+ }
+ }
+
+ test("FairT[IO].toFs2: finite stream") {
+ val stream = FairT.mplus[IO, Int](
+ FairT.unit[IO, Int](1),
+ FairT.mplus[IO, Int](
+ FairT.unit[IO, Int](2),
+ FairT.unit[IO, Int](3)
+ )
+ )
+ for {
+ result <- stream.toFs2.compile.toList
+ expected <- FairT.runM[IO, Int](None, None, stream)
+ } yield assertEquals(result, expected)
+ }
+
+ test("FairT[IO].toFs2: pythagorean triples match runM output") {
+ def guardF(cond: Boolean): FairT[IO, Unit] =
+ if (cond) FairT.unit[IO, Unit](()) else FairT.empty[IO, Unit]
+
+ lazy val number: FairT[IO, Int] =
+ FairT.mplus[IO, Int](FairT.unit[IO, Int](0), number.map(_ + 1))
+
+ val triples: FairT[IO, (Int, Int, Int)] = for {
+ i <- number
+ _ <- guardF(i > 0)
+ j <- number
+ _ <- guardF(j > 0)
+ k <- number
+ _ <- guardF(k > 0)
+ _ <- guardF(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val n = 7
+ for {
+ result <- triples.toFs2.take(n.toLong).compile.toList
+ expected <- FairT.runM[IO, (Int, Int, Int)](None, Some(n), triples)
+ } yield assertEquals(result, expected)
+ }
+
+ // -- Plain fs2.Stream vs Fair interleaving --
+
+ test("plain fs2.Stream cannot find pythagorean triples (depth-first gets stuck)") {
+ // Plain fs2.Stream uses depth-first (sequential) flatMap: for i=1, j=1 it
+ // tries k=1,2,3,... forever, never advancing j or i. So it cannot produce
+ // even a single triple from an infinite number stream within a budget.
+ val number: _root_.fs2.Stream[IO, Int] = _root_.fs2.Stream.iterate(1)(_ + 1)
+
+ val triples = for {
+ i <- number
+ j <- number
+ k <- number
+ if i * i + j * j == k * k
+ } yield (i, j, k)
+
+ // Give it a generous timeout — still finds nothing.
+ triples
+ .take(1)
+ .interruptAfter(3.seconds)
+ .compile
+ .toList
+ .map { result =>
+ assertEquals(result, List.empty[(Int, Int, Int)])
+ }
+ }
+
+ test("Fair.toFs2 finds pythagorean triples thanks to fair interleaving") {
+ import Fair._
+ lazy val number: Fair[Int] = mplus(unit(0), number.map(_ + 1))
+
+ val triples = for {
+ i <- number
+ _ <- guard(i > 0)
+ j <- number
+ _ <- guard(j > 0)
+ k <- number
+ _ <- guard(k > 0)
+ _ <- guard(i * i + j * j == k * k)
+ } yield (i, j, k)
+
+ val result = triples.toFs2.take(7).toList
+ assertEquals(result.length, 7)
+ assert(result.forall { case (i, j, k) => i * i + j * j == k * k })
+ }
+
+}