diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 8ae390b610..93373f5a88 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -97,6 +97,11 @@ func main() { Usage: "Do not delete container on exit", Destination: &args.Docker.KeepContainer, }, + &cli.BoolFlag{ + Name: "privileged", + Usage: "Give extended privileges to the container", + Destination: &args.Docker.Privileged, + }, &cli.StringFlag{ Name: "ssh-key", Usage: "Public SSH key", @@ -104,6 +109,12 @@ func main() { Destination: &args.Docker.ConcatinatedPublicSSHKeys, EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"}, }, + &cli.StringFlag{ + Name: "pjrt-device", + Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)", + Destination: &args.Docker.PJRTDevice, + EnvVars: []string{"PJRT_DEVICE"}, + }, &cli.BoolFlag{ Name: "service", Usage: "Start as a service", diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 935a934d56..49573f1fd5 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -9,6 +9,7 @@ import ( "os/exec" "path/filepath" "strconv" + "strings" "sync" "syscall" "time" @@ -205,6 +206,15 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error "DSTACK_GPUS_NUM": strconv.Itoa(gpus_num), } + // Call buildLDLibraryPathEnv and update jobEnvs if no error occurs + newLDPath, err := buildLDLibraryPathEnv() + if err != nil { + log.Info(ctx, "Continuing without updating LD_LIBRARY_PATH") + } else { + jobEnvs["LD_LIBRARY_PATH"] = newLDPath + log.Info(ctx, "New LD_LIBRARY_PATH set", newLDPath) + } + cmd := exec.CommandContext(ctx, ex.jobSpec.Commands[0], ex.jobSpec.Commands[1:]...) cmd.Env = makeEnv(ex.homeDir, jobEnvs, ex.jobSpec.Env, ex.secrets) cmd.Cancel = func() error { @@ -291,3 +301,38 @@ func isPtyError(err error) bool { var e *os.PathError return errors.As(err, &e) && e.Err == syscall.EIO } + +func buildLDLibraryPathEnv() (string, error) { + // Execute shell command to get Python prefix + cmd := exec.Command("bash", "-i", "-c", "python3-config --prefix") + output, err := cmd.Output() + + if err != nil { + return "", fmt.Errorf("error executing command: %v", err) + } + + // Extract and trim the prefix path + prefixPath := strings.TrimSpace(string(output)) + + // Check if the prefix path exists + if _, err := os.Stat(prefixPath); os.IsNotExist(err) { + return "", fmt.Errorf("python prefix path does not exist: %s", prefixPath) + } + + // Construct the path to Python's shared libraries + sharedLibPath := fmt.Sprintf("%s/lib", prefixPath) + + // Get current LD_LIBRARY_PATH + currentLDPath := os.Getenv("LD_LIBRARY_PATH") + + // Append Python's shared library path if not already present + if !strings.Contains(currentLDPath, sharedLibPath) { + if currentLDPath == "" { + currentLDPath = sharedLibPath + } else { + currentLDPath = fmt.Sprintf("%s:%s", currentLDPath, sharedLibPath) + } + } + + return currentLDPath, nil +} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index ac174ffda6..4d0cda007f 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -312,13 +312,21 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str return "", tracerr.Wrap(err) } + //Set the environment variables + envVars := []string{} + if dockerParams.DockerPJRTDevice() != "" { + envVars = append(envVars, fmt.Sprintf("PJRT_DEVICE=%s", dockerParams.DockerPJRTDevice())) + } + containerConfig := &container.Config{ Image: taskConfig.ImageName, Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), + Env: envVars, } hostConfig := &container.HostConfig{ + Privileged: dockerParams.DockerPrivileged(), NetworkMode: getNetworkMode(), PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, @@ -426,6 +434,14 @@ func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } +func (c CLIArgs) DockerPrivileged() bool { + return c.Docker.Privileged +} + +func (c CLIArgs) DockerPJRTDevice() string { + return c.Docker.PJRTDevice +} + func (c CLIArgs) DockerShellCommands(publicKeys []string) []string { concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys if len(publicKeys) > 0 { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 5db6d523a1..5823bdfdf1 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -98,6 +98,14 @@ func (c *dockerParametersMock) DockerKeepContainer() bool { return false } +func (c *dockerParametersMock) DockerPrivileged() bool { + return false +} + +func (c *dockerParametersMock) DockerPJRTDevice() string { + return "" +} + func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string { userPublicKey := c.publicSSHKey if len(publicKeys) > 0 { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 9de45d8631..d596ef5c39 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -10,11 +10,13 @@ import ( ) type DockerParameters interface { + DockerPrivileged() bool DockerKeepContainer() bool DockerShellCommands([]string) []string DockerMounts(string) ([]mount.Mount, error) DockerPorts() []int MakeRunnerDir() (string, error) + DockerPJRTDevice() string } type CLIArgs struct { @@ -38,6 +40,8 @@ type CLIArgs struct { SSHPort int KeepContainer bool ConcatinatedPublicSSHKeys string + Privileged bool + PJRTDevice string } } diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 35e484dc92..988c37fbd9 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -20,7 +20,6 @@ logger = get_logger(__name__) - DSTACK_WORKING_DIR = "/root/.dstack" @@ -121,14 +120,16 @@ def get_shim_env(build: str, authorized_keys: List[str]) -> Dict[str, str]: return envs -def get_shim_commands(authorized_keys: List[str]) -> List[str]: +def get_shim_commands( + authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None +) -> List[str]: build = get_dstack_runner_version() commands = get_shim_pre_start_commands( build, ) for k, v in get_shim_env(build, authorized_keys).items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script() + commands += get_run_shim_script(is_privileged, pjrt_device) return commands @@ -161,10 +162,13 @@ def get_shim_pre_start_commands(build: str) -> List[str]: ] -def get_run_shim_script() -> List[str]: +def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" + privileged_flag = "--privileged" if is_privileged else "" + pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else "" + return [ - f"nohup dstack-shim {dev_flag} docker --keep-container >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", + f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} {pjrt_device_env} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", ] diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 475bc2c722..62bfd8a66d 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -443,7 +443,9 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]: def _get_tpu_startup_script(authorized_keys: List[str]) -> str: - commands = get_shim_commands(authorized_keys=authorized_keys) + commands = get_shim_commands( + authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU" + ) startup_script = " ".join([" && ".join(commands)]) startup_script = "#! /bin/bash\n" + startup_script return startup_script @@ -469,9 +471,9 @@ def _is_pod(instance_name: str) -> bool: tensor_cores = int(tensor_cores) except ValueError: raise ValueError(f"Invalid number in tpu tensor cores: {tensor_cores}") - if version in ["v2", "v3"]: + if version in ["v2", "v3", "v5p", "v5litepod"]: return tensor_cores > 8 - elif version in ["v4", "v5p", "v5litepod"]: + elif version == "v4": return True else: raise ValueError(f"Unknown TPU version: {version}")