How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I am trying Ray Serve to serve my model, but unfortunately I got stuck on a strange error: "TypeError: cannot pickle 'classmethod_descriptor' object".
Here is my code:
import argparse
from typing import Optional, List, Tuple
import torch
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
import requests
from fastapi import FastAPI
from ray import serve
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=17860)
parser.add_argument("--model-path", type=str, default="THUDM/chatglm-6b")
parser.add_argument("--precision", type=str, help="evaluate at this precision",
                    choices=["fp32", "bf16", "fp16", "int4", "int8"])
parser.add_argument("--listen", action='store_true',
                    help="listen on 0.0.0.0, allowing the server to respond to network requests")
parser.add_argument("--cpu", action='store_true', help="use cpu")
parser.add_argument("--device-id", type=str,
                    help="select the default CUDA device to use", default=None)
cmd_opts = parser.parse_args()
def load_model():
    # load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(
        cmd_opts.model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        cmd_opts.model_path, trust_remote_code=True)
    # load model with precision
    if cmd_opts.cpu:
        if cmd_opts.precision == "fp32":
            model = model.float()
        elif cmd_opts.precision == "bf16":
            model = model.bfloat16()
        else:
            model = model.float()
    else:
        if cmd_opts.precision is None:
            total_vram_in_gb = get_device_properties(0).total_memory / 1e9
            print(f'GPU memory: {total_vram_in_gb:.2f} GB')
            if total_vram_in_gb > 30:
                cmd_opts.precision = 'fp32'
            elif total_vram_in_gb > 13:
                cmd_opts.precision = 'fp16'
            elif total_vram_in_gb > 10:
                cmd_opts.precision = 'int8'
            else:
                cmd_opts.precision = 'int4'
            print(f'Choosing precision {cmd_opts.precision} according to your VRAM.'
                  f' If you want to decide precision yourself,'
                  f' please add argument --precision when launching the application.')
        if cmd_opts.precision == "fp16":
            model = model.half().cuda()
        elif cmd_opts.precision == "int4":
            model = model.half().quantize(4).cuda()
        elif cmd_opts.precision == "int8":
            model = model.half().quantize(8).cuda()
        elif cmd_opts.precision == "fp32":
            model = model.float()
    # load model in eval mode
    model = model.eval()
    return model, tokenizer
def infer(model, tokenizer, query,
          history: Optional[List[Tuple]],
          max_length, top_p, temperature, use_stream_chat: bool):
    # Generate a response for the given query
    if not model:
        raise RuntimeError("Model not loaded")
    if history is None:
        history = []
    output_pos = 0
    try:
        if use_stream_chat:
            for output, history in model.stream_chat(
                tokenizer, query=query, history=history,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature
            ):
                # yield only the newly generated portion of the output
                old_output_pos = output_pos
                output_pos = len(output)
                yield query, output[old_output_pos:]
        else:
            output, history = model.chat(
                tokenizer, query=query, history=history,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature
            )
            yield query, output
    except Exception as e:
        print(f"Generation failed: {repr(e)}")
    # Free up GPU memory
    if torch.cuda.is_available():
        device = torch.device(
            f"cuda:{cmd_opts.device_id}" if cmd_opts.device_id is not None else "cuda")
        with torch.cuda.device(device):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
model, tokenizer = load_model()
# 1: Define a FastAPI app and wrap it in a deployment with a route handler.
app = FastAPI()

@serve.deployment(route_prefix="/")
@serve.ingress(app)
class ChatbotModel:
    def predict(self, query, max_length, top_p, temperature, use_stream_chat):
        # Handle the HTTP request and collect the generated responses
        responses = []
        for _, output in infer(
            model=model,
            tokenizer=tokenizer,
            query=query,
            history=None,  # ctx.history,
            max_length=max_length,
            top_p=top_p,
            temperature=temperature,
            use_stream_chat=use_stream_chat
        ):
            responses.append(output.strip())
            print(output)
        # Return the response as a JSON object
        return {"response": responses}

    # FastAPI will automatically parse the HTTP request for us.
    @app.get("/hello")
    def predict_hello(self, query: str):
        return self.predict(query, max_length=2048, top_p=0.9,
                            temperature=0.7, use_stream_chat=False)
# 2: Deploy the deployment.
serve.start()
ChatbotModel.deploy()
# 3: Query the deployment and print the result.
print(requests.get("http://localhost:8000/hello",
                   params={"query": "What is transformer"}).json()["response"])
"
The script looks fine to me, yet the error pops up when Ray tries to pickle the ChatbotModel class.
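My best guess (unconfirmed) is that Serve has to cloudpickle the ChatbotModel class to ship it to the replica actor, and since predict references the module-level model and tokenizer, the ChatGLM model object (built with trust_remote_code=True) gets pulled into the pickle and fails on one of its classmethod descriptors. Below is a minimal sketch of the workaround I am considering: moving load_model() into the deployment's __init__, so the model is constructed inside the replica instead of being serialized with the class. I have not verified that this avoids the error.

# Sketch (untested): construct the model inside the replica so Ray only
# pickles the class definition, not the live model object. The module-level
# `model, tokenizer = load_model()` line would be removed in this variant.
@serve.deployment(route_prefix="/")
@serve.ingress(app)
class ChatbotModel:
    def __init__(self):
        # Runs inside the replica actor after the class has been shipped,
        # so the ChatGLM model never has to go through cloudpickle.
        self.model, self.tokenizer = load_model()

    @app.get("/hello")
    def predict_hello(self, query: str):
        responses = [output.strip() for _, output in infer(
            model=self.model,
            tokenizer=self.tokenizer,
            query=query,
            history=None,
            max_length=2048,
            top_p=0.9,
            temperature=0.7,
            use_stream_chat=False,
        )]
        return {"response": responses}

Does this pattern look right, or is something else making the class unpicklable?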