Ray Tune not working inside Databricks

One idea: how can I override PYTHONPATH so that the PySpark path is removed from it? Maybe that would solve the issue.

I have tried setting PYTHONPATH both when running `ray start` and in `ray.init()`, without success:

#!/bin/bash
# Cluster init script: installs Ray 1.13.0 into the Databricks Python
# environment and starts a Ray node. On the Spark driver it starts the Ray head
# node; on every other node it joins that head at $DB_DRIVER_IP:$RAY_PORT.

# Port the Ray head node listens on (and that workers connect to)
RAY_PORT=9339
# NOTE(review): plaintext password committed in the script; consider loading it
# from a secret store instead. The Python side passes the same literal to ray.init.
REDIS_PASS="d4t4bricks"

# install ray
/databricks/python/bin/pip install ray==1.13.0

# Install additional ray libraries.
# Quoted so the shell never glob-expands the [extras] brackets against
# files that happen to exist in the current directory.
/databricks/python/bin/pip install "ray[debug,dashboard,tune,rllib,serve]==1.13.0"

# If starting on the Spark driver node, initialize the Ray head node.
# If starting on a Spark worker node, connect to the head Ray node.
# "${DB_IS_DRIVER:-}" is quoted so an unset/empty value compares cleanly
# instead of producing a malformed test expression.
if [ "${DB_IS_DRIVER:-}" = "TRUE" ]; then
  echo "Starting the head node"
  # The VAR=value prefix sets PYTHONPATH only in the environment of this
  # `ray start` invocation (and the processes it launches), not the whole script.
  PYTHONPATH=/databricks/python/lib/python3.8/site-packages ray start \
    --min-worker-port=20000 --max-worker-port=25000 \
    --temp-dir="/tmp/ray" \
    --head --port="$RAY_PORT" \
    --redis-password="$REDIS_PASS" \
    --include-dashboard=false
else
  # Give the driver time to bring the head node up before trying to join it.
  sleep 40
  echo "Starting the non-head node - connecting to $DB_DRIVER_IP:$RAY_PORT"
  PYTHONPATH=/databricks/python/lib/python3.8/site-packages ray start \
    --min-worker-port=20000 --max-worker-port=25000 \
    --temp-dir="/tmp/ray" \
    --address="$DB_DRIVER_IP:$RAY_PORT" \
    --redis-password="$REDIS_PASS"
fi
# Load the USAFacts COVID deaths CSV from the Databricks sample datasets.
# header=True takes column names from the first line; inferSchema=True lets
# Spark infer column types rather than reading every column as a string.
covid_df = spark.read.csv(
    '/databricks-datasets/COVID/USAFacts/covid_deaths_usafacts.csv',
    header=True,
    inferSchema=True,
)

# Every column after the first four is packed into the array below.
# NOTE(review): assumes the first four columns are identifier fields and the
# rest are per-date counts — confirm against the CSV header.
select_cols = covid_df.columns[4:]

# One output row per input row: the county name plus all selected columns
# combined into a single array column named "deaths".
df = covid_df.select(
    col('County Name').alias('county_name'),
    array([col(column_name) for column_name in select_cols]).alias('deaths'),
)

@ray.remote
def linear_pred(x, y, i):
    """Fit an ElasticNet regressor on (x, y) and predict at feature value i + 1.

    Executed as a Ray remote task.

    Args:
        x: Training feature matrix, presumably of shape (n_samples, 1).
        y: Training targets aligned with ``x``.
        i: Index of the last training point; the prediction is made for the
            single input ``[[i + 1]]``.

    Returns:
        The first element of the model's prediction array for ``[[i + 1]]``.
    """
    # NOTE(review): ``linear_model`` and ``np`` are resolved from module scope
    # (presumably sklearn.linear_model and numpy); they must be importable on
    # the Ray worker. To inspect the worker's import path while debugging,
    # temporarily ``import sys`` and ``return sys.path`` instead of raising.
    reg = linear_model.ElasticNet().fit(x, y)
    p = reg.predict(np.array([[i + 1]]))
    return p[0]
  
 
@pandas_udf(ArrayType(LongType()))
def ray_udf(s):
    """Produce one prediction per element of a batch by fanning out Ray tasks.

    Connects to the already-running Ray cluster (``address='auto'``). For each
    position ``i`` in the batch, it submits ``linear_pred`` with features
    ``[[0], ..., [i]]`` and targets ``s[:i + 1]``, then gathers all results.

    Args:
        s: pandas Series holding one batch of the input column.

    Returns:
        pandas Series of ``linear_pred`` results, one per input element, in
        input order.
    """
    # ignore_reinit_error: the UDF may run repeatedly in the same Python process,
    # so ray.init may already have been called there.
    # runtime_env env_vars asks Ray to set PYTHONPATH for this job's workers.
    ray.init(ignore_reinit_error=True, address='auto', _redis_password='d4t4bricks', runtime_env = {"env_vars": {"PYTHONPATH": "/databricks/python/lib/python3.8/site-packages"}})
    rows = list(s)

    # NOTE(review): each element of ``s`` is one row's value (one county's
    # "deaths" array), so ``y`` stacks rows 0..i rather than time steps of a
    # single series. If a per-county time-series fit was intended, the loop
    # should iterate within each row instead — confirm intent.
    futures = []
    for i in range(len(rows)):
        x = np.arange(i + 1).reshape(-1, 1)  # column vector [[0], [1], ..., [i]]
        y = np.asarray(rows[: i + 1])
        futures.append(linear_pred.remote(x, y, i))

    # NOTE(review): ElasticNet predictions are floats, but the declared return
    # type is ArrayType(LongType()); confirm the intended element type.
    return pd.Series(ray.get(futures))
 
# Keep the county name and its deaths array, and add a "preds" column computed
# by the Ray-backed pandas UDF; then render the result in the notebook.
res = df.select(
    "county_name",
    "deaths",
    ray_udf("deaths").alias("preds"),
)
display(res)