Adapted from AWS samples for Ray.

The idea is to make host algo-1 the head node (master host) of the Ray cluster; all other hosts (algo-2, algo-3, etc.) join it as worker nodes.

import subprocess
import os
import time
import ray
import socket
import json
import sys
 
class RayHelper:
    """Start a Ray cluster across the hosts listed in ``SM_HOSTS``.

    The first host in ``SM_HOSTS`` is used as the head node; every other
    host starts a Ray worker that connects to it.

    Args:
        ray_port: Port the head node listens on, passed to ``ray start``.
        redis_pass: Value passed to ``ray start --redis-password``.
    """

    # Seconds between two successive checks of the cluster membership.
    _POLL_INTERVAL = 5

    def __init__(
        self,
        ray_port: str = "9339",
        redis_pass: str = "redis_password",
    ):
        self.ray_port = ray_port
        self.redis_pass = redis_pass
        self.resource_config = self.get_resource_config()
        # master_host is the first listed host (algo-1)
        self.master_host = self.resource_config["hosts"][0]
        self.n_hosts = len(self.resource_config["hosts"])

    @staticmethod
    def get_resource_config():
        """Read the host layout from the environment.

        Returns:
            A dict with ``current_host`` (value of ``SM_CURRENT_HOST``) and
            ``hosts`` (the JSON-decoded value of ``SM_HOSTS``).
        """
        return dict(
            current_host=os.environ.get("SM_CURRENT_HOST"),
            hosts=json.loads(os.environ.get("SM_HOSTS")),
        )

    def _get_ip_from_host(self):
        """Resolve the head node's hostname, retrying once per second.

        Returns:
            The IP address of ``self.master_host`` as a string.

        Raises:
            TimeoutError: If the name still does not resolve after
                ``ip_wait_time`` attempts.
        """
        ip_wait_time = 200
        last_error = None

        for _ in range(ip_wait_time):
            try:
                return socket.gethostbyname(self.master_host)
            except OSError as err:
                # Only resolution failures are retried; a bare ``except``
                # would also swallow KeyboardInterrupt / SystemExit.
                last_error = err
                time.sleep(1)

        raise TimeoutError(
            "Exceeded max wait time of {}s for hostname resolution".format(ip_wait_time)
        ) from last_error

    @staticmethod
    def _run_and_echo(command):
        """Run *command* to completion and print its captured stdout."""
        output = subprocess.run(command, stdout=subprocess.PIPE)
        print(output.stdout.decode("utf-8"))
        return output

    def start_ray(self):
        """Start Ray on this host as head or worker, depending on its role.

        On the head node this starts the Ray head process, connects the
        driver to it, waits for all hosts to join and prints the cluster's
        resources before returning. On any other host it starts a blocking
        Ray worker and exits the process once that command returns.
        """
        # Resolved on every host, so the head also waits until its own name
        # is resolvable before starting.
        master_ip = self._get_ip_from_host()

        if self.resource_config["current_host"] == self.master_host:
            self._run_and_echo(
                [
                    'ray',
                    'start',
                    '--head',
                    '-vvv',
                    '--port',
                    self.ray_port,
                    '--redis-password',
                    self.redis_pass,
                    '--include-dashboard',
                    'false',
                ]
            )
            ray.init(address="auto", include_dashboard=False)
            self._wait_for_workers()
            print("All workers present and accounted for")

            print("Available resources:")
            print(ray.available_resources())

            print("\nCluster resources:")
            print(ray.cluster_resources())
        else:
            # Presumably gives the head node time to come up before this
            # worker tries to connect -- the delay is fixed, not verified.
            time.sleep(10)
            self._run_and_echo(
                [
                    'ray',
                    'start',
                    f"--address={master_ip}:{self.ray_port}",
                    '--redis-password',
                    self.redis_pass,
                    "--block",
                ]
            )
            # Workers never run the driver code that follows start_ray().
            sys.exit(0)

    def _wait_for_workers(self, timeout=60):
        """Block until every expected host has joined the cluster.

        Args:
            timeout: Approximate maximum number of seconds to wait.

        Raises:
            TimeoutError: If fewer than ``self.n_hosts`` nodes are connected
                once the timeout has elapsed.
        """
        print(f"Waiting up to {timeout} seconds for {self.n_hosts} nodes to join")

        remaining = timeout
        while True:
            connected = len(ray.nodes())
            if connected >= self.n_hosts:
                return
            # ``<= 0`` rather than ``== 0``: a timeout that is not a multiple
            # of the poll interval would otherwise never trigger.
            if remaining <= 0:
                raise TimeoutError("Max timeout for nodes to join exceeded")
            print(f"{connected} nodes connected to cluster")
            time.sleep(self._POLL_INTERVAL)
            remaining -= self._POLL_INTERVAL

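Then, instead of calling ray.init() directly, run ray_helper = RayHelper(); ray_helper.start_ray(). On the head node this returns once all hosts have joined the cluster; on worker nodes it blocks while the Ray worker runs and then exits the process.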

This can be used to batch-process model runs or any other embarrassingly parallel workload.