Tensor parallel inference with DeepSpeed on Ray

I am trying tensor-parallel inference with DeepSpeed on Ray. I wrote a DeepspeedTPWorker class that holds one shard of the pre-trained model, and a DeepspeedPredictor class that manages a group of workers which together load one complete model in tensor parallelism. I then use an ActorPool to manage several DeepspeedPredictor actors.

I ran the script below on a machine with 16 cores and two A40s (46 GB of memory each).

# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# +
import ray

ray.init()
ray.available_resources()
# -
import torch

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" 

# +
from transformers import AutoTokenizer


class tokenize:
    def __init__(self, model_name=model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, text):
        batch_tokens = self.tokenizer(
            text["begginnings"].tolist(),
            return_tensors="pt",
            max_length=128,
            padding="max_length",
            truncation=True,
        )
        return dict(batch_tokens)



# +
import pandas as pd
from ray import data

begginnings_all = [
    "The King is dead. Long live the Queen.",
    "Once there were four children whose names were Peter, Susan, Edmund, and Lucy.",
    "The story so far: in the beginning, the universe was created.",
    "It was a bright cold day in April, and the clocks were striking thirteen.",
    "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.",
    "The sweat wis lashing oafay Sick Boy; he wis trembling.",
    "124 was spiteful. Full of Baby's venom.",
    "As Gregor Samsa awoke one morning from uneasy dreams he found himself transformed in his bed into a gigantic insect.",
    "I write this sitting in the kitchen sink.",
    "We were somewhere around Barstow on the edge of the desert when the drugs began to take hold.",
] * 10

dataset = data.from_pandas(pd.DataFrame({"begginnings": begginnings_all}))
ray_dataset = dataset.map_batches(tokenize, concurrency=2, batch_size=10)
print(ray_dataset)
# -

from contextlib import closing
import socket


# +
import ray.data

from transformers import AutoModelForCausalLM

import torch
import torch.distributed as dist



from typing import Dict, Any

import os

import pandas as pd



@ray.remote(num_cpus=2, num_gpus=1)
class DeepspeedTPWorker:
    def __init__(
        self,
        model_rank: int,
        local_rank: int,
        local_world_size: int,
        ip_address: str,
        port: int,
        model,
        ds_config=None,
        dtype=torch.float16,
    ):
        print(f"is cuda available: {torch.cuda.is_available()}")

        self.model_rank = model_rank
        self.local_rank = local_rank
        self.local_world_size = local_world_size

        self.ip_address = ip_address
        self.port = port

        self.model = model
        self.ds_config = ds_config
        self.dtype = dtype

        print(
            f"rank {self.local_rank} assigned {torch.cuda.device_count()} CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}"
        )

        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{self.ip_address}:{self.port}",  
            rank=self.local_rank,
            world_size=self.local_world_size,
        )

        
    def load_ds_engine(self):
        import deepspeed
        
        deepspeed.init_distributed("nccl")
        self.model = deepspeed.init_inference(self.model, self.ds_config)
        
        # with open(f"layers-{self.local_rank}","w") as f:
        #     for name, param in self.model.named_parameters():
        #         f.write(f"Rank: {self.local_rank}, Parameter: {name}, Mean: {param.mean().item()}\n")

    def _predict_pandas(self, batch: Dict[str, Any]) -> pd.DataFrame:
        with torch.no_grad():
            out_tokens = self.model.generate(**batch)
            out_tokens = out_tokens.to("cpu")

        return pd.DataFrame(
            {
                "generated_tokens": out_tokens,
                "decoded_texts": self.tokenizer.batch_decode(
                    out_tokens, skip_special_tokens=True
                ),
            }
        )

# +

@ray.remote(num_cpus=2)
class DeepspeedPredictor:
    def __init__(self, model, num_workers_per_model, model_rank) -> None:
        self.model = model
        self.num_workers_per_model = num_workers_per_model
        self.ds_config = {
            "tensor_parallel": {
                    "enabled": True,
                    "tp_size": num_workers_per_model,
            },
        }

        self.comm_port = self._find_free_port()
        self.comm_master_ip = ray.util.get_node_ip_address()
        
        self.workers = [
            DeepspeedTPWorker.remote(model_rank, i, self.num_workers_per_model, self.comm_master_ip, self.comm_port, model_ref, self.ds_config)
            for i in range(self.num_workers_per_model)
        ]
        
        self.load_model()
        
    def load_model(self):
        
        ray.get([p.load_ds_engine.remote() for p in self.workers])
        
    
    def _find_free_port(self) -> int:
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(("", 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            return s.getsockname()[1]
    
    
    def predict(self, batch: Dict[str, Any]) -> pd.DataFrame:
        predictions = ray.get([worker._predict_pandas.remote(batch) for worker in self.workers])
        combined_output = pd.concat(predictions, ignore_index=True)
        return combined_output


# +

model = AutoModelForCausalLM.from_pretrained(model_name)
model_ref = ray.put(model)

# +

from ray.util.actor_pool import ActorPool

num_parallel_models = 1
num_workers_per_model = 2
predictors = [DeepspeedPredictor.remote(model_ref, num_workers_per_model, i) for i in range(num_parallel_models)]

pool = ActorPool(predictors)

for it in ray_dataset.iter_torch_batches():
    pool.submit(lambda a, v: a.predict.remote(v), it)
while pool.has_next():
    print(pool.get_next())


# -
ray.shutdown()

Then I hit the following error:

2024-09-26 11:17:16,721 INFO worker.py:1598 -- Connecting to existing Ray cluster at address: 10.0.8.98:6379...
2024-09-26 11:17:16,733 INFO worker.py:1774 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
MapBatches(tokenize)
+- Dataset(num_rows=100, schema={begginnings: object})
2024-09-26 11:18:27,225 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-09-26_07-38-13_206865_51653/logs/ray-data
2024-09-26 11:18:27,225 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(tokenize)]
✔️  Dataset execution finished in 13.34 seconds: 100 row [00:13, 7.50 row/s]
2024-09-26 11:18:40,567 WARNING actor_pool_map_operator.py:265 -- To ensure full parallelization across an actor pool of size 2, the Dataset should consist of at least 2 distinct blocks. Consider increasing the parallelism when creating the Dataset.
- MapBatches(tokenize): 0 active, 0 queued, [cpu: 1.0, objects: 0.0B], 0 actors [locality off]: 100 row [00:00, 763 row/s]
(DeepspeedTPWorker pid=268595) is cuda available: True                                                                      
(DeepspeedTPWorker pid=268595) rank 0 assigned 1 CUDA_VISIBLE_DEVICES: 0
(DeepspeedTPWorker pid=268595) [2024-09-26 11:18:53,653] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(DeepspeedTPWorker pid=268595) [2024-09-26 11:18:55,169] [INFO] [comm.py:637:init_distributed] cdb=None
(DeepspeedTPWorker pid=268595) [2024-09-26 11:18:55,169] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.14.5, git-hash=unknown, git-branch=unknown
(DeepspeedTPWorker pid=268595) [2024-09-26 11:18:55,171] [INFO] [logging.py:96:log_dist] [Rank 0] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
(DeepspeedTPWorker pid=268595) AutoTP:  [(<class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'>, ['self_attn.o_proj', 'mlp.down_proj'])]
Traceback (most recent call last):
  File "/home/ubuntu/oms/tutorials/notebooks/inference_NxN_TPDP_ray_deepspeed.py", line 222, in <module>
    print("Prediction output size:", pool.get_next())
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/util/actor_pool.py", line 309, in get_next
    return ray.get(future)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RaySystemError): ray::DeepspeedPredictor.predict() (pid=268209, ip=10.0.8.98, actor_id=42678021ed8f98f0802215e423000000, repr=<inference_NxN_TPDP_ray_deepspeed.DeepspeedPredictor object at 0x7fb2c4350dc0>)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
traceback: Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/storage.py", line 381, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1272, in _legacy_load
    result = unpickler.load()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1205, in persistent_load
    obj = restore_location(obj, location)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 390, in default_restore_location
    result = fn(storage, location)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 265, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 249, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(DeepspeedPredictor pid=268209) Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(DeepspeedPredictor pid=268209) Traceback (most recent call last):
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 423, in deserialize_objects
(DeepspeedPredictor pid=268209)     obj = self._deserialize_object(data, metadata, object_ref)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 280, in _deserialize_object
(DeepspeedPredictor pid=268209)     return self._deserialize_msgpack_data(data, metadata_fields)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 235, in _deserialize_msgpack_data
(DeepspeedPredictor pid=268209)     python_objects = self._deserialize_pickle5_data(pickle5_data)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 225, in _deserialize_pickle5_data
(DeepspeedPredictor pid=268209)     obj = pickle.loads(in_band)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/storage.py", line 381, in _load_from_bytes
(DeepspeedPredictor pid=268209)     return torch.load(io.BytesIO(b))
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
(DeepspeedPredictor pid=268209)     return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1272, in _legacy_load
(DeepspeedPredictor pid=268209)     result = unpickler.load()
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 1205, in persistent_load
(DeepspeedPredictor pid=268209)     obj = restore_location(obj, location)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 390, in default_restore_location
(DeepspeedPredictor pid=268209)     result = fn(storage, location)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 265, in _cuda_deserialize
(DeepspeedPredictor pid=268209)     device = validate_cuda_device(location)
(DeepspeedPredictor pid=268209)   File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/serialization.py", line 249, in validate_cuda_device
(DeepspeedPredictor pid=268209)     raise RuntimeError('Attempting to deserialize object on a CUDA '
(DeepspeedPredictor pid=268209) RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(DeepspeedTPWorker pid=268596) is cuda available: True
(DeepspeedTPWorker pid=268596) rank 1 assigned 1 CUDA_VISIBLE_DEVICES: 1
(DeepspeedTPWorker pid=268596) [2024-09-26 11:18:53,656] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(DeepspeedTPWorker pid=268596) [2024-09-26 11:18:55,198] [INFO] [comm.py:637:init_distributed] cdb=None
(DeepspeedTPWorker pid=268596) AutoTP:  [(<class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'>, ['self_attn.o_proj', 'mlp.down_proj'])]

It seems the actor can't find CUDA when DeepspeedPredictor.predict fetches the results returned by DeepspeedTPWorker._predict_pandas. But I checked torch.cuda.is_available() inside DeepspeedTPWorker and it returned True. I'd like to know what happens when ray.get() runs inside DeepspeedPredictor.predict.
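
To make sure I understand the failure mode, I put together a minimal standalone sketch (no DeepSpeed, actor names made up) of what I suspect is happening: an actor that does not request GPUs apparently cannot see CUDA (the DeepspeedPredictor traceback shows torch.cuda.is_available() is False in that process), so it fails while deserializing any argument that contains CUDA tensors.

# Standalone sketch of the suspected mechanism; GpuActor/CpuActor are made-up names.
import ray
import torch

ray.init()


@ray.remote(num_gpus=1)
class GpuActor:
    def make_cuda_tensor(self):
        # A tensor living on this actor's GPU.
        return torch.ones(4, device="cuda")


@ray.remote(num_cpus=1)  # no GPU requested, so this process should not see CUDA
class CpuActor:
    def consume(self, tensor):
        # Deserializing the argument calls torch.load on CUDA storage, which I
        # expect to raise "Attempting to deserialize object on a CUDA device
        # but torch.cuda.is_available() is False".
        return tensor.shape


gpu_actor = GpuActor.remote()
cpu_actor = CpuActor.remote()

tensor_ref = gpu_actor.make_cuda_tensor.remote()
try:
    print(ray.get(cpu_actor.consume.remote(tensor_ref)))
except Exception as err:
    print("Failed as I expected:", err)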

Can you try doing this inside __init__ instead of passing the model via an object reference?



model = AutoModelForCausalLM.from_pretrained(model_name)
model_ref = ray.put(model)
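
In other words, roughly something like this (untested sketch, keeping the rest of your worker as it is): hand the worker only the model name and call from_pretrained inside the worker process, so the full nn.Module never goes through ray.put / the object store.

import ray
import torch
from transformers import AutoModelForCausalLM


@ray.remote(num_cpus=2, num_gpus=1)
class DeepspeedTPWorker:
    def __init__(self, model_rank, local_rank, local_world_size,
                 ip_address, port, model_name, ds_config=None,
                 dtype=torch.float16):
        # Only the model name crosses the object store.
        self.model_name = model_name
        self.ds_config = ds_config
        # ... rest of your __init__ (process group setup, etc.) unchanged

    def load_ds_engine(self):
        import deepspeed

        deepspeed.init_distributed("nccl")
        # Load the weights here, inside the GPU worker process.
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.model = deepspeed.init_inference(model, self.ds_config)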

I changed my code accordingly but still hit the same error. I added some print statements and found that when pool.get_next() triggers DeepspeedPredictor.predict, the expected output of print("Starting predict") never appears. Is the input of DeepspeedPredictor.predict(self, batch: Dict[str, Any]) automatically moved to the GPU during preprocessing?

# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# +
# # %env HTTP_PROXY=http://127.0.0.1:7890
# # %env HTTPS_PROXY=http://127.0.0.1:7890
# # %env ALL_PROXY=socks5://127.0.0.1:7890

# +
import ray

ray.init()
ray.available_resources()
# -
import torch

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" 

# +
from transformers import AutoTokenizer


class tokenize:
    def __init__(self, model_name=model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, text):
        batch_tokens = self.tokenizer(
            text["begginnings"].tolist(),
            return_tensors="pt",
            max_length=128,
            padding="max_length",
            truncation=True,
        )
        return dict(batch_tokens)



# +
import pandas as pd
from ray import data

begginnings_all = [
    "The King is dead. Long live the Queen.",
    "Once there were four children whose names were Peter, Susan, Edmund, and Lucy.",
    "The story so far: in the beginning, the universe was created.",
    "It was a bright cold day in April, and the clocks were striking thirteen.",
    "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.",
    "The sweat wis lashing oafay Sick Boy; he wis trembling.",
    "124 was spiteful. Full of Baby's venom.",
    "As Gregor Samsa awoke one morning from uneasy dreams he found himself transformed in his bed into a gigantic insect.",
    "I write this sitting in the kitchen sink.",
    "We were somewhere around Barstow on the edge of the desert when the drugs began to take hold.",
] * 10

dataset = data.from_pandas(pd.DataFrame({"begginnings": begginnings_all}))
ray_dataset = dataset.map_batches(tokenize, concurrency=1, batch_size=10)
print(ray_dataset)
# -

from contextlib import closing
import socket


# +
import ray.data

from transformers import AutoModelForCausalLM

import torch
import torch.distributed as dist



from typing import Dict, Any

import os

import pandas as pd



@ray.remote(num_cpus=2, num_gpus=1)
class DeepspeedTPWorker:
    def __init__(
        self,
        model_rank: int,
        local_rank: int,
        local_world_size: int,
        ip_address: str,
        port: int,
        model_name,
        ds_config=None,
        dtype=torch.float16,
    ):
        print(f"is cuda available: {torch.cuda.is_available()}")

        self.model_name = model_name
        self.ds_config = ds_config

        self.model_rank = model_rank
        self.local_rank = local_rank
        self.local_world_size = local_world_size

        self.ip_address = ip_address
        self.port = port

        self.dtype = dtype

        print(
            f"rank {self.local_rank} assigned {torch.cuda.device_count()} CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}"
        )

        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{self.ip_address}:{self.port}",  
            rank=self.local_rank,
            world_size=self.local_world_size,
        )
        
        

        
    def load_ds_engine(self):
        import deepspeed
        
        deepspeed.init_distributed("nccl")
        self.model = deepspeed.init_inference(AutoModelForCausalLM.from_pretrained(self.model_name), self.ds_config)
        
        # with open(f"layers-{self.local_rank}","w") as f:
        #     for name, param in self.model.named_parameters():
        #         f.write(f"Rank: {self.local_rank}, Parameter: {name}, Mean: {param.mean().item()}\n")

    def _predict_pandas(self, batch: Dict[str, Any]) -> pd.DataFrame:
        print("Starting _predict_pandas")
        with torch.no_grad():
            out_tokens = self.model.generate(**batch)
            out_tokens = out_tokens.to("cpu")
        
        decode_result = self.tokenizer.batch_decode(
            out_tokens, skip_special_tokens=True
        )
        print("Decode Result: ", decode_result)
        
        return pd.DataFrame(
            {
                "generated_tokens": out_tokens,
                "decoded_texts": decode_result,
            }
        )

# +

@ray.remote(num_cpus=2)
class DeepspeedPredictor:
    def __init__(self, model_name, num_workers_per_model, model_rank) -> None:
        self.model_name = model_name
        self.num_workers_per_model = num_workers_per_model
        self.ds_config = {
            "tensor_parallel": {
                    "enabled": True,
                    "tp_size": num_workers_per_model,
            },
        }
        self.model_rank = model_rank
        self.comm_port = self._find_free_port()
        self.comm_master_ip = ray.util.get_node_ip_address()
        
    def init_worker_group(self):
        
        self.workers = [
            DeepspeedTPWorker.remote(self.model_rank, i, self.num_workers_per_model, self.comm_master_ip, self.comm_port, self.model_name, self.ds_config)
            for i in range(self.num_workers_per_model)
        ]
        
    def load_model(self):
        ray.get([p.load_ds_engine.remote() for p in self.workers])
        
    
    def _find_free_port(self) -> int:
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(("", 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            return s.getsockname()[1]
    
    
    def predict(self, batch: Dict[str, Any]) -> pd.DataFrame:
        print("Starting predict")
        predictions = ray.get([worker._predict_pandas.remote(batch) for worker in self.workers])
        print("Finished creating predictions")
        combined_output = pd.concat(predictions, ignore_index=True)
        return combined_output


# +

# model = AutoModelForCausalLM.from_pretrained(model_name)
# model_ref = ray.put(model)

# +

from ray.util.actor_pool import ActorPool

num_parallel_models = 1
num_workers_per_model = 2

print("Setting up predictors")
predictors = [DeepspeedPredictor.remote(model_name, num_workers_per_model, i) for i in range(num_parallel_models)]
ray.get([p.init_worker_group.remote() for p in predictors])
ray.get([p.load_model.remote() for p in predictors])

pool = ActorPool(predictors)
print("Created ActorPool")

for it in ray_dataset.iter_torch_batches():
    print("Starting predict")
    pool.submit(lambda a, v: a.predict.remote(v), it)
    print("Finished submitting predict")
while pool.has_next():
    print(pool.get_next())


# -
ray.shutdown()

Here is the output of the updated code:

2024-10-01 07:32:28,214 INFO worker.py:1598 -- Connecting to existing Ray cluster at address: 10.80.0.53:6379...
2024-10-01 07:32:28,226 INFO worker.py:1774 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
MapBatches(tokenize)
+- Dataset(num_rows=100, schema={begginnings: object})
Setting up predictors
(DeepspeedTPWorker pid=1475631) is cuda available: True
(DeepspeedTPWorker pid=1475631) rank 0 assigned 1 CUDA_VISIBLE_DEVICES: 0
(DeepspeedTPWorker pid=1475631) [2024-10-01 07:32:38,474] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(DeepspeedTPWorker pid=1475631) [2024-10-01 07:32:40,083] [INFO] [comm.py:637:init_distributed] cdb=None
(DeepspeedTPWorker pid=1475631) [2024-10-01 07:32:42,437] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.14.5, git-hash=unknown, git-branch=unknown
(DeepspeedTPWorker pid=1475632) is cuda available: True
(DeepspeedTPWorker pid=1475632) rank 1 assigned 1 CUDA_VISIBLE_DEVICES: 1
(DeepspeedTPWorker pid=1475631) [2024-10-01 07:32:42,438] [INFO] [logging.py:96:log_dist] [Rank 0] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
(DeepspeedTPWorker pid=1475631) AutoTP:  [(<class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'>, ['mlp.down_proj', 'self_attn.o_proj'])]
Created ActorPool
2024-10-01 07:32:44,220 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2024-10-01_05-31-29_931398_1323789/logs/ray-data
2024-10-01 07:32:44,220 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(tokenize)]
✔️  Dataset execution finished in 2.60 seconds: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 38.5 row/s] 
- MapBatches(tokenize): 0 active, 0 queued, [cpu: 1.0, objects: 200.8KB], 0 actors [locality off]: 100%|████████████████████████████████████████████| 100/100 [00:00<00:00, 1.71k row/s]
Starting predict                                                                                                                                                                        
Finished submitting predict
(DeepspeedPredictor pid=1475533) Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(DeepspeedPredictor pid=1475533) Traceback (most recent call last):
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 423, in deserialize_objects
(DeepspeedPredictor pid=1475533)     obj = self._deserialize_object(data, metadata, object_ref)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 280, in _deserialize_object
(DeepspeedPredictor pid=1475533)     return self._deserialize_msgpack_data(data, metadata_fields)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 235, in _deserialize_msgpack_data
(DeepspeedPredictor pid=1475533)     python_objects = self._deserialize_pickle5_data(pickle5_data)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/ray/_private/serialization.py", line 225, in _deserialize_pickle5_data
(DeepspeedPredictor pid=1475533)     obj = pickle.loads(in_band)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/storage.py", line 381, in _load_from_bytes
(DeepspeedPredictor pid=1475533)     return torch.load(io.BytesIO(b))
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
(DeepspeedPredictor pid=1475533)     return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1272, in _legacy_load
(DeepspeedPredictor pid=1475533)     result = unpickler.load()
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1205, in persistent_load
(DeepspeedPredictor pid=1475533)     obj = restore_location(obj, location)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 390, in default_restore_location
(DeepspeedPredictor pid=1475533)     result = fn(storage, location)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 265, in _cuda_deserialize
(DeepspeedPredictor pid=1475533)     device = validate_cuda_device(location)
(DeepspeedPredictor pid=1475533)   File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 249, in validate_cuda_device
(DeepspeedPredictor pid=1475533)     raise RuntimeError('Attempting to deserialize object on a CUDA '
(DeepspeedPredictor pid=1475533) RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Traceback (most recent call last):
  File "/home/test/Desktop/oms/tutorials/notebooks/inference_NxN_TPDP_ray_deepspeed.py", line 233, in <module>
    print(pool.get_next())
  File "/home/test/.local/lib/python3.10/site-packages/ray/util/actor_pool.py", line 309, in get_next
    return ray.get(future)
  File "/home/test/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/test/.local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/test/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/test/.local/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RaySystemError): ray::DeepspeedPredictor.predict() (pid=1475533, ip=10.80.0.53, actor_id=223ca431ec4794bf6361536a14000000, repr=<inference_NxN_TPDP_ray_deepspeed.DeepspeedPredictor object at 0x7b00fc0f65f0>)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
traceback: Traceback (most recent call last):
  File "/home/test/.local/lib/python3.10/site-packages/torch/storage.py", line 381, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1272, in _legacy_load
    result = unpickler.load()
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1205, in persistent_load
    obj = restore_location(obj, location)
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 390, in default_restore_location
    result = fn(storage, location)
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 265, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 249, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
(DeepspeedTPWorker pid=1475632) [2024-10-01 07:32:38,478] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(DeepspeedTPWorker pid=1475632) [2024-10-01 07:32:40,077] [INFO] [comm.py:637:init_distributed] cdb=None
(DeepspeedTPWorker pid=1475632) AutoTP:  [(<class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'>, ['mlp.down_proj', 'self_attn.o_proj'])]
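
One more thing I plan to try, in case the batches produced by iter_torch_batches() are already CUDA tensors on the driver side (which would explain why the CPU-only DeepspeedPredictor fails while deserializing its input): pin the batches to the CPU before submitting them to the pool. This is an untested sketch that would replace the submit loop at the end of my script; the device= argument of iter_torch_batches is my reading of the Ray Data docs.

# Print each tensor's device and force everything onto the CPU before it is
# shipped to the CPU-only DeepspeedPredictor actors.
for it in ray_dataset.iter_torch_batches(device="cpu"):
    print({k: v.device for k, v in it.items()})   # expecting all "cpu"
    it = {k: v.to("cpu") for k, v in it.items()}  # belt and braces
    pool.submit(lambda a, v: a.predict.remote(v), it)

while pool.has_next():
    print(pool.get_next())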