How to test a Scala Play Framework websocket?

Adapted for Play Framework 2.7:

import java.util.concurrent.ExecutionException
import java.util.function.Consumer

import com.typesafe.scalalogging.StrictLogging
import play.shaded.ahc.org.asynchttpclient.AsyncHttpClient
import play.shaded.ahc.org.asynchttpclient.netty.ws.NettyWebSocket
import play.shaded.ahc.org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler}

import scala.compat.java8.FutureConverters
import scala.concurrent.Future

/**
 * A WebSocket listener that logs connection events and forwards received text frames to a callback.
 *
 * As soon as the connection is opened, a `"hello"` text frame is sent to the server.
 *
 * @param onMessageCallback invoked with the payload of every text frame received. Each frame is forwarded
 *                          as-is, including non-final fragments; no reassembly is done here.
 */
class LoggingListener(onMessageCallback: Consumer[String]) extends WebSocketListener with StrictLogging {

  override def onOpen(websocket: WebSocket): Unit = {
    logger.info("onOpen: ")
    websocket.sendTextFrame("hello")
  }

  override def onClose(webSocket: WebSocket, i: Int, s: String): Unit =
    logger.info("onClose: ")

  override def onError(t: Throwable): Unit =
    logger.error("onError: ", t)

  override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit = {
    logger.debug(s"$payload $finalFragment $rsv")
    onMessageCallback.accept(payload)
  }

}

/**
 * Thin Scala wrapper around an AsyncHttpClient for opening WebSocket connections in tests.
 *
 * @param client the (shaded) AsyncHttpClient used to perform the HTTP upgrade request.
 */
class WebSocketClient(client: AsyncHttpClient) {

  /**
   * Sends a GET upgrade request to `url` with the given `Origin` header and registers `listener` on the
   * resulting connection.
   *
   * NOTE(review): the `@throws` annotations are kept for Java callers. The body itself only builds and
   * executes the request; the outcome of the handshake is delivered through the returned Future.
   *
   * @param url      the `ws://` URL to connect to.
   * @param origin   value sent in the `Origin` request header.
   * @param listener receives the WebSocket lifecycle and frame callbacks.
   * @return a Future completed with the open socket once the upgrade handshake finishes.
   */
  @throws[ExecutionException]
  @throws[InterruptedException]
  def call(url: String, origin: String, listener: WebSocketListener): Future[NettyWebSocket] = {
    val upgradeHandler =
      new WebSocketUpgradeHandler.Builder()
        .addWebSocketListener(listener)
        .build()
    val pending =
      client
        .prepareGet(url)
        .addHeader("Origin", origin)
        .execute(upgradeHandler)
    FutureConverters.toScala(pending.toCompletableFuture)
  }
}

And in the test:

  // Open a WebSocket to the alarm endpoint of the test server and wait up to one
  // second for the upgrade handshake; incoming frames are logged at debug level.
  val wsUrl = s"ws://localhost:$port/api/alarm/ws"
  val logIncoming: Consumer[String] = (message: String) => logger.debug(message)
  val connection =
    new WebSocketClient(client.underlying[AsyncHttpClient])
      .call(wsUrl, "ws://example.com/ws", new LoggingListener(logIncoming))
  Await.result(connection, atMost = 1000.millis)

Play 2.6

I followed this example: play-scala-websocket-example

Main steps:

Create or provide a WebSocketClient that you can use in your tests.

Create the client:

// Reuse the AsyncHttpClient behind Play's WSClient to drive the WebSocket handshake.
val asyncHttpClient = wsClient.underlying[AsyncHttpClient]
val webSocketClient: WebSocketClient = new WebSocketClient(asyncHttpClient)

Connect to the serverURL:

// Every message received from the server is put on `queue` for the assertions below.
val listener = new WebSocketClient.LoggingListener(queue.put(_))
val f = FutureConverters.toScala(webSocketClient.call(serverURL, origin, listener))

Test the messages sent by the server:

whenReady(f, timeout = Timeout(1.second)) { webSocket =>
  // Wait until the socket is open and at least the first server message has arrived.
  await().until(() => webSocket.isOpen && !queue.isEmpty)

  // Poll with a timeout instead of take(): take() blocks forever when the server sends
  // fewer messages than expected, which hangs the test instead of failing it.
  def nextMessage(): String =
    Option(queue.poll(1, java.util.concurrent.TimeUnit.SECONDS))
      .getOrElse(fail("Timed out waiting for the next message from the server"))

  checkMsg1(nextMessage())
  checkMsg2(nextMessage())
  assert(queue.isEmpty)
}

For example:

  /**
   * Asserts that `msg` is the JSON encoding of `AdapterNotRunning(None)`; fails the test otherwise.
   *
   * @param msg the raw message text received from the server.
   */
  private def checkMsg1(msg: String): Unit = {
    val json: JsValue = Json.parse(msg)
    json.validate[AdapterMsg] match {
      case JsSuccess(AdapterNotRunning(None), _) => // ok
      case other                                 => fail(s"Unexpected result: $other")
    }
  }

The whole example can be found here: scala-adapters (JobCockpitControllerSpec)


This is a complete example that uses the Akka HTTP WebSocket client to test a WebSocket controller. It contains some custom code, but it demonstrates several test scenarios. It works with Play 2.7.

package controllers

import java.util.concurrent.{ LinkedBlockingDeque, TimeUnit }

import actors.WSBridge
import akka.Done
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.{ Origin, RawHeader }
import akka.http.scaladsl.model.ws.{ BinaryMessage, Message, TextMessage, WebSocketRequest }
import akka.http.scaladsl.model.{ HttpResponse, StatusCodes, Uri }
import akka.stream.scaladsl.{ Flow, Keep, Sink, Source, SourceQueueWithComplete }
import akka.stream.{ ActorMaterializer, OverflowStrategy }
import models.WSTopic
import org.specs2.matcher.JsonMatchers
import play.api.Logging
import play.api.inject.guice.GuiceApplicationBuilder
import play.api.test._

import scala.collection.immutable.Seq
import scala.concurrent.Future

/**
 * Test case for the [[WSController]] actor.
 */
class WSControllerSpec extends ForServer with WSControllerSpecContext with JsonMatchers {

  "The `socket` method" should {
    "return a 403 status code if the origin doesn't match" >> { implicit rs: RunningServer =>
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint)))

      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.Forbidden
      }
    }

    "return a 400 status code if the topic cannot be found" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))

      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "return a 400 status code if the topic syntax isn't valid in query param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"))
      // The raw query string must not contain the leading '?'; with it the server would see a parameter
      // named "?topic" and reject the request for a missing topic instead of an invalid one.
      val request = WebSocketRequest(endpoint.withRawQueryString("topic=."), headers)
      val maybeSocket = await(websocketClient.connect(request))

      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "return a 400 status code if the topic syntax isn't valid in header param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "."))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))

      maybeSocket must beLeft[HttpResponse].like { case response =>
        response.status must be equalTo StatusCodes.BadRequest
      }
    }

    "receive an acknowledge message when connecting to a topic via query param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"))
      val request = WebSocketRequest(endpoint.withRawQueryString("topic=%2Fflowers%2Frose"), headers)
      val maybeSocket = await(websocketClient.connect(request))

      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (_, messages) =>
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/rose")).message.toJson.toString()
      }
    }

    "receive an acknowledge message when connecting to a topic via header param" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "/flowers/tulip"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))

      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (_, messages) =>
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/tulip")).message.toJson.toString()
      }
    }

    "receive a pong message when sending a ping" >> { implicit rs: RunningServer =>
      val headers = Seq(Origin("http://localhost:9443"), RawHeader("X-TOPIC", "/flowers/tulip"))
      val maybeSocket = await(websocketClient.connect(WebSocketRequest(endpoint, headers)))

      maybeSocket must beRight[(SourceQueue, MessageQueue)].like { case (queue, messages) =>
        queue.offer(WSBridge.Ping.toJson.toString())

        // The acknowledge message for the subscribed topic arrives before the pong.
        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Ack(WSTopic("/flowers/tulip")).message.toJson.toString()

        messages.poll(1000, TimeUnit.MILLISECONDS) must be equalTo
          WSBridge.Pong.toJson.toString()
      }
    }
  }
}

/**
 * The context for the [[WSControllerSpec]].
 */
trait WSControllerSpecContext extends ForServer with PlaySpecification with ApplicationFactories {

  type SourceQueue = SourceQueueWithComplete[String]
  type MessageQueue = LinkedBlockingDeque[String]

  /**
   * Provides the application factory.
   */
  protected def applicationFactory: ApplicationFactory = withGuiceApp(GuiceApplicationBuilder())

  /**
   * Gets the WebSocket endpoint.
   *
   * @param rs The running server.
   * @return The WebSocket endpoint.
   * @throws IllegalStateException if the running server exposes no HTTP endpoint.
   */
  protected def endpoint(implicit rs: RunningServer): Uri = {
    val httpEndpoint = rs.endpoints.httpEndpoint.getOrElse(
      throw new IllegalStateException("The running test server has no HTTP endpoint")
    )
    Uri(httpEndpoint.pathUrl("/ws").replace("http://", "ws://"))
  }

  /**
   * Provides an instance of the WebSocket client.
   *
   * This should be a method to return a fresh client for every test.
   */
  protected def websocketClient = new AkkaWebSocketClient

  /**
   * An Akka WebSocket client that is optimized for testing.
   */
  class AkkaWebSocketClient extends Logging {

    /**
     * The queue of received messages.
     */
    private val messageQueue = new LinkedBlockingDeque[String]()

    /**
     * Connect to the WebSocket.
     *
     * Each call creates its own ActorSystem. That system is terminated when the upgrade is rejected, when the
     * connection attempt fails, or when the channel closes.
     * NOTE(review): a connection that is still open when the test ends keeps its ActorSystem alive.
     *
     * @param wsRequest The WebSocket request instance.
     * @return Either an [[HttpResponse]] if the upgrade process wasn't successful or a source and a message queue
     *         to which new messages may be offered.
     */
    def connect(wsRequest: WebSocketRequest): Future[Either[HttpResponse, (SourceQueue, MessageQueue)]] = {
      implicit val system: ActorSystem = ActorSystem()
      implicit val materializer: ActorMaterializer = ActorMaterializer()
      import system.dispatcher

      // Collapse each incoming message into a String and store it in the messages queue. `mapAsync(1)` keeps
      // streamed messages in arrival order. It also means the sink only completes after every message has been
      // queued.
      val incoming: Sink[Message, Future[Done]] =
        Flow[Message]
          .mapAsync[String](1) {
            case TextMessage.Strict(s)     => Future.successful(s)
            case TextMessage.Streamed(s)   => s.runFold("")(_ + _)
            case BinaryMessage.Strict(s)   => Future.successful(s.utf8String)
            case BinaryMessage.Streamed(s) => s.runFold("")(_ + _.utf8String)
          }
          .toMat(Sink.foreach[String](msg => messageQueue.offer(msg)))(Keep.right)

      // Out source is a queue to which we can offer messages that will be sent to the WebSocket server.
      // All offered messages will be transformed into WebSocket messages.
      val sourceQueue = Source.queue[String](Int.MaxValue, OverflowStrategy.backpressure)
        .map { msg => TextMessage.Strict(msg) }
      val (sourceMat, source) = sourceQueue.preMaterialize()

      // The outgoing flow sends all messages which are offered to the queue (our stream source) to the WebSocket
      // server.
      val flow: Flow[Message, Message, Future[Done]] = Flow.fromSinkAndSourceMat(incoming, source)(Keep.left)

      // UpgradeResponse is a Future[WebSocketUpgradeResponse] that completes or fails when the connection succeeds
      // or fails and closed is a Future[Done] representing the stream completion from above
      val (upgradeResponse, closed) = Http().singleWebSocketRequest(wsRequest, flow)
      closed.foreach(_ => logger.info("Channel closed"))

      val result: Future[Either[HttpResponse, (SourceQueue, MessageQueue)]] = upgradeResponse.map { upgrade =>
        if (upgrade.response.status == StatusCodes.SwitchingProtocols) {
          Right((sourceMat, messageQueue))
        } else {
          Left(upgrade.response)
        }
      }

      // Release the per-connection ActorSystem once it can no longer be used. Termination is only triggered after
      // `result` has completed, so the upgrade outcome always reaches the caller before the dispatcher shuts down.
      result.onComplete { outcome =>
        if (!outcome.toOption.exists(_.isRight)) system.terminate()
      }
      closed.onComplete(_ => result.onComplete(_ => system.terminate()))

      result
    }
  }
}