0 votes
2 views
in Big Data Hadoop & Spark by (11.4k points)

I'm running a Spark Streaming task in a cluster using YARN. Each node in the cluster runs multiple spark workers. Before the streaming starts, I want to execute a "setup" function on all workers on all nodes in the cluster.

The streaming task classifies incoming messages as spam or not spam, but before it can do that it needs to download the latest pre-trained models from HDFS to local disks, like this pseudo code example:

def fetch_models():
    if hadoop.version > local.version:
        hadoop.download()


I've seen the following examples here on SO:

sc.parallelize().map(fetch_models)


But in Spark 1.6 parallelize() requires some data to be used, like this shitty workaround I'm doing now:

sc.parallelize(range(1, 1000)).map(fetch_models)


Just to be fairly sure that the function is run on ALL workers I set the range to 1000. I also don't exactly know how many workers are in the cluster when running.

I've read the programming documentation and googled relentlessly but I can't seem to find any way to actually just distribute anything to all workers without any data.

After this initialization phase is done, the streaming task is, as usual, operating on incoming data from Kafka.

The way I'm using the models is by running a function similar to this:

spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
    .repartition(spark_partitions)\
    .foreachRDD(lambda rdd: rdd.foreachPartition(
        lambda partition: spam.on_partition(config, partition)))


Theoretically, I could check whether or not the models are up to date in the on_partition function, though it would be really wasteful to do this on each batch. I'd like to do it before Spark starts retrieving batches from Kafka, since the downloading from HDFS can take a couple of minutes...

To be clear: the issue is not how to distribute the files or how to load them, but how to run an arbitrary method on all workers without operating on any data.

To clarify what loading the models currently means:

def on_partition(config, partition):
    if not MyClassifier.is_loaded():
        MyClassifier.load_models(config)

    handle_partition(config, partition)


While MyClassifier is something like this:

class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        MyClassifier.clf = load_from_file(config)


I use static methods because PySpark doesn't seem to be able to serialize classes with non-static methods (the state of the class on one worker is irrelevant to any other worker). Here we only have to call load_models() once, and on all future batches MyClassifier.clf will already be set. This really should not be done for each batch; it's a one-time thing. The same goes for downloading the files from HDFS using fetch_models().

1 Answer

0 votes
by (32.3k points)

I would suggest using the SparkFiles mechanism. I think it is the simplest way to distribute a file to the worker machines:

some_path = ...  # local file, a file in DFS, an HTTP, HTTPS or FTP URI.

sc.addFile(some_path)

And then just retrieve it on the workers using SparkFiles.get and standard IO tools:

import os
from pyspark import SparkFiles

# SparkFiles.get expects the file name, not the original path or URI
with open(SparkFiles.get(os.path.basename(some_path))) as f:
    ...  # Do something
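As a rough illustration of the "standard IO tools" step: the thread never says how the model is stored, so the sketch below simply assumes it is a pickled Python object. The body of load_from_file, the file name model.pkl and the stand-in dict are all hypothetical and only there to make the example runnable without Spark:

```python
import os
import pickle
import tempfile


def load_from_file(path):
    """Hypothetical loader: assumes the model was saved with pickle."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip with a stand-in "model" (a plain dict) in a temporary directory.
# On a worker, `path` would instead be the value returned by SparkFiles.get.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")
    with open(path, "wb") as f:
        pickle.dump({"threshold": 0.5}, f)
    model = load_from_file(path)

print(model)
```

If the model is stored in another format, only the body of load_from_file would need to change.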

If you want to make sure that the model is actually loaded, the simplest approach is to load it on module import. Assuming config can be used to retrieve the model path:

  • model.py:

    import os

    from pyspark import SparkFiles

    config = ...

    class MyClassifier:
        clf = None

        @staticmethod
        def is_loaded():
            return MyClassifier.clf is not None

        @staticmethod
        def load_models(config):
            # Look the file up by name, as registered via sc.addFile
            path = SparkFiles.get(os.path.basename(config.get("model_file")))
            MyClassifier.clf = load_from_file(path)

    # Executed once per interpreter
    MyClassifier.load_models(config)

  • main.py:

    from pyspark import SparkContext

    config = ...

    sc = SparkContext("local", "foo")

    # Executed before StreamingContext starts
    sc.addFile(config.get("model_file"))
    sc.addPyFile("model.py")

    import model

    ssc = ...
    stream = ...
    stream.map(model.MyClassifier.do_something).pprint()

    ssc.start()
    ssc.awaitTermination()
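The effect of keeping the model in class-level state can be sketched without Spark. The stand-in classifier, the load_count counter and the sample batches below are illustrative assumptions, not part of the original code. In a real job this state exists once per Python worker process, so it survives across batches only while Spark keeps reusing those processes (as far as I know this is controlled by spark.python.worker.reuse, which I believe is enabled by default):

```python
class MyClassifier:
    clf = None      # shared by every task that runs in this process
    load_count = 0  # illustrative counter, not in the original class

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        # Stand-in for the expensive load_from_file(...) call.
        MyClassifier.clf = lambda message: "spam" in message
        MyClassifier.load_count += 1


def on_partition(config, partition):
    # Cheap guard: only the first call in a given process pays for the load.
    if not MyClassifier.is_loaded():
        MyClassifier.load_models(config)
    return [MyClassifier.clf(message) for message in partition]


config = {}  # placeholder config
batches = [["spam offer", "hello"], ["lunch?"], ["more spam"]]
results = [on_partition(config, batch) for batch in batches]
print(results, MyClassifier.load_count)
```

Three simulated batches go through on_partition, but load_models runs only on the first one. After that, the per-batch cost is a single attribute check.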
