From 924c823eb806c4dcf543c48fd964218cd461b1f8 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Mon, 24 Jun 2024 18:45:09 +0545 Subject: [PATCH 1/4] tpu initial release --- runner/cmd/shim/main.go | 5 +++++ runner/internal/shim/docker.go | 8 ++++++++ runner/internal/shim/models.go | 2 ++ src/dstack/_internal/core/backends/base/compute.py | 10 +++++----- src/dstack/_internal/core/backends/gcp/compute.py | 6 +++--- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 8ae390b610..6edd40d6d7 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -97,6 +97,11 @@ func main() { Usage: "Do not delete container on exit", Destination: &args.Docker.KeepContainer, }, + &cli.BoolFlag{ + Name: "privileged", + Usage: "Give extended privileges to the container", + Destination: &args.Docker.Privileged, + }, &cli.StringFlag{ Name: "ssh-key", Usage: "Public SSH key", diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index ac174ffda6..cdc84f1e7c 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -317,8 +317,12 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), + Env: []string{ + "PJRT_DEVICE=TPU", + }, } hostConfig := &container.HostConfig{ + Privileged: dockerParams.DockerPrivileged(), NetworkMode: getNetworkMode(), PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, @@ -426,6 +430,10 @@ func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } +func (c CLIArgs) DockerPrivileged() bool { + return c.Docker.Privileged +} + func (c CLIArgs) DockerShellCommands(publicKeys []string) []string { concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys if len(publicKeys) > 0 { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 9de45d8631..e04fe16b02 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -10,6 +10,7 @@ import ( ) type DockerParameters interface { + DockerPrivileged() bool DockerKeepContainer() bool DockerShellCommands([]string) []string DockerMounts(string) ([]mount.Mount, error) @@ -38,6 +39,7 @@ type CLIArgs struct { SSHPort int KeepContainer bool ConcatinatedPublicSSHKeys string + Privileged bool } } diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 35e484dc92..a9b49d55e4 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -20,7 +20,6 @@ logger = get_logger(__name__) - DSTACK_WORKING_DIR = "/root/.dstack" @@ -121,14 +120,14 @@ def get_shim_env(build: str, authorized_keys: List[str]) -> Dict[str, str]: return envs -def get_shim_commands(authorized_keys: List[str]) -> List[str]: +def get_shim_commands(authorized_keys: List[str], *, is_privileged: bool = False) -> List[str]: build = get_dstack_runner_version() commands = get_shim_pre_start_commands( build, ) for k, v in get_shim_env(build, authorized_keys).items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script() + commands += get_run_shim_script(is_privileged) return commands @@ -161,10 +160,11 @@ def get_shim_pre_start_commands(build: str) -> List[str]: ] -def get_run_shim_script() -> List[str]: +def get_run_shim_script(is_privileged: bool) -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" + privileged_flag = "--privileged" if is_privileged else "" return [ - f"nohup dstack-shim {dev_flag} docker --keep-container >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", + f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", ] diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 475bc2c722..83c4cd2196 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -443,7 +443,7 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]: def _get_tpu_startup_script(authorized_keys: List[str]) -> str: - commands = get_shim_commands(authorized_keys=authorized_keys) + commands = get_shim_commands(authorized_keys=authorized_keys, is_privileged=True) startup_script = " ".join([" && ".join(commands)]) startup_script = "#! /bin/bash\n" + startup_script return startup_script @@ -469,9 +469,9 @@ def _is_pod(instance_name: str) -> bool: tensor_cores = int(tensor_cores) except ValueError: raise ValueError(f"Invalid number in tpu tensor cores: {tensor_cores}") - if version in ["v2", "v3"]: + if version in ["v2", "v3", "v5p", "v5litepod"]: return tensor_cores > 8 - elif version in ["v4", "v5p", "v5litepod"]: + elif version == "v4": return True else: raise ValueError(f"Unknown TPU version: {version}") From 82784cfd8d0bbda3aec2696ba1892d4bbdc4c8c9 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Mon, 24 Jun 2024 18:45:09 +0545 Subject: [PATCH 2/4] tpu initial release --- runner/cmd/shim/main.go | 5 +++++ runner/internal/shim/docker.go | 8 ++++++++ runner/internal/shim/docker_test.go | 4 ++++ runner/internal/shim/models.go | 2 ++ src/dstack/_internal/core/backends/base/compute.py | 10 +++++----- src/dstack/_internal/core/backends/gcp/compute.py | 6 +++--- 6 files changed, 27 insertions(+), 8 deletions(-) diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 8ae390b610..6edd40d6d7 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -97,6 +97,11 @@ func main() { Usage: "Do not delete container on exit", Destination: &args.Docker.KeepContainer, }, + &cli.BoolFlag{ + Name: "privileged", + Usage: "Give extended privileges to the container", + Destination: &args.Docker.Privileged, + }, &cli.StringFlag{ Name: "ssh-key", Usage: "Public SSH key", diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index ac174ffda6..cdc84f1e7c 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -317,8 +317,12 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), + Env: []string{ + "PJRT_DEVICE=TPU", + }, } hostConfig := &container.HostConfig{ + Privileged: dockerParams.DockerPrivileged(), NetworkMode: getNetworkMode(), PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, @@ -426,6 +430,10 @@ func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } +func (c CLIArgs) DockerPrivileged() bool { + return c.Docker.Privileged +} + func (c CLIArgs) DockerShellCommands(publicKeys []string) []string { concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys if len(publicKeys) > 0 { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 5db6d523a1..5be30ce0b2 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -98,6 +98,10 @@ func (c *dockerParametersMock) DockerKeepContainer() bool { return false } +func (c *dockerParametersMock) DockerPrivileged() bool { + return false +} + func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string { userPublicKey := c.publicSSHKey if len(publicKeys) > 0 { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 9de45d8631..e04fe16b02 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -10,6 +10,7 @@ import ( ) type DockerParameters interface { + DockerPrivileged() bool DockerKeepContainer() bool DockerShellCommands([]string) []string DockerMounts(string) ([]mount.Mount, error) @@ -38,6 +39,7 @@ type CLIArgs struct { SSHPort int KeepContainer bool ConcatinatedPublicSSHKeys string + Privileged bool } } diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 35e484dc92..a9b49d55e4 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -20,7 +20,6 @@ logger = get_logger(__name__) - DSTACK_WORKING_DIR = "/root/.dstack" @@ -121,14 +120,14 @@ def get_shim_env(build: str, authorized_keys: List[str]) -> Dict[str, str]: return envs -def get_shim_commands(authorized_keys: List[str]) -> List[str]: +def get_shim_commands(authorized_keys: List[str], *, is_privileged: bool = False) -> List[str]: build = get_dstack_runner_version() commands = get_shim_pre_start_commands( build, ) for k, v in get_shim_env(build, authorized_keys).items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script() + commands += get_run_shim_script(is_privileged) return commands @@ -161,10 +160,11 @@ def get_shim_pre_start_commands(build: str) -> List[str]: ] -def get_run_shim_script() -> List[str]: +def get_run_shim_script(is_privileged: bool) -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" + privileged_flag = "--privileged" if is_privileged else "" return [ - f"nohup dstack-shim {dev_flag} docker --keep-container >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", + f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", ] diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 475bc2c722..83c4cd2196 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -443,7 +443,7 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]: def _get_tpu_startup_script(authorized_keys: List[str]) -> str: - commands = get_shim_commands(authorized_keys=authorized_keys) + commands = get_shim_commands(authorized_keys=authorized_keys, is_privileged=True) startup_script = " ".join([" && ".join(commands)]) startup_script = "#! /bin/bash\n" + startup_script return startup_script @@ -469,9 +469,9 @@ def _is_pod(instance_name: str) -> bool: tensor_cores = int(tensor_cores) except ValueError: raise ValueError(f"Invalid number in tpu tensor cores: {tensor_cores}") - if version in ["v2", "v3"]: + if version in ["v2", "v3", "v5p", "v5litepod"]: return tensor_cores > 8 - elif version in ["v4", "v5p", "v5litepod"]: + elif version == "v4": return True else: raise ValueError(f"Unknown TPU version: {version}") From a046bfae76606e11d32ecd9d8df56bc05cdec9d2 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 25 Jun 2024 23:13:20 +0545 Subject: [PATCH 3/4] Fix env variable setting mechanism --- runner/cmd/shim/main.go | 8 +++- runner/internal/executor/executor.go | 45 +++++++++++++++++++ runner/internal/shim/docker.go | 8 +++- runner/internal/shim/docker_test.go | 4 ++ runner/internal/shim/models.go | 4 +- .../_internal/core/backends/base/compute.py | 12 +++-- .../_internal/core/backends/gcp/compute.py | 4 +- 7 files changed, 76 insertions(+), 9 deletions(-) diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 6edd40d6d7..93373f5a88 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -97,7 +97,7 @@ func main() { Usage: "Do not delete container on exit", Destination: &args.Docker.KeepContainer, }, - &cli.BoolFlag{ + &cli.BoolFlag{ Name: "privileged", Usage: "Give extended privileges to the container", Destination: &args.Docker.Privileged, @@ -109,6 +109,12 @@ func main() { Destination: &args.Docker.ConcatinatedPublicSSHKeys, EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"}, }, + &cli.StringFlag{ + Name: "pjrt-device", + Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)", + Destination: &args.Docker.PJRTDevice, + EnvVars: []string{"PJRT_DEVICE"}, + }, &cli.BoolFlag{ Name: "service", Usage: "Start as a service", diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 935a934d56..49573f1fd5 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -9,6 +9,7 @@ import ( "os/exec" "path/filepath" "strconv" + "strings" "sync" "syscall" "time" @@ -205,6 +206,15 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error "DSTACK_GPUS_NUM": strconv.Itoa(gpus_num), } + // Call buildLDLibraryPathEnv and update jobEnvs if no error occurs + newLDPath, err := buildLDLibraryPathEnv() + if err != nil { + log.Info(ctx, "Continuing without updating LD_LIBRARY_PATH") + } else { + jobEnvs["LD_LIBRARY_PATH"] = newLDPath + log.Info(ctx, "New LD_LIBRARY_PATH set", newLDPath) + } + cmd := exec.CommandContext(ctx, ex.jobSpec.Commands[0], ex.jobSpec.Commands[1:]...) cmd.Env = makeEnv(ex.homeDir, jobEnvs, ex.jobSpec.Env, ex.secrets) cmd.Cancel = func() error { @@ -291,3 +301,38 @@ func isPtyError(err error) bool { var e *os.PathError return errors.As(err, &e) && e.Err == syscall.EIO } + +func buildLDLibraryPathEnv() (string, error) { + // Execute shell command to get Python prefix + cmd := exec.Command("bash", "-i", "-c", "python3-config --prefix") + output, err := cmd.Output() + + if err != nil { + return "", fmt.Errorf("error executing command: %v", err) + } + + // Extract and trim the prefix path + prefixPath := strings.TrimSpace(string(output)) + + // Check if the prefix path exists + if _, err := os.Stat(prefixPath); os.IsNotExist(err) { + return "", fmt.Errorf("python prefix path does not exist: %s", prefixPath) + } + + // Construct the path to Python's shared libraries + sharedLibPath := fmt.Sprintf("%s/lib", prefixPath) + + // Get current LD_LIBRARY_PATH + currentLDPath := os.Getenv("LD_LIBRARY_PATH") + + // Append Python's shared library path if not already present + if !strings.Contains(currentLDPath, sharedLibPath) { + if currentLDPath == "" { + currentLDPath = sharedLibPath + } else { + currentLDPath = fmt.Sprintf("%s:%s", currentLDPath, sharedLibPath) + } + } + + return currentLDPath, nil +} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index cdc84f1e7c..7391c22bc4 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -318,11 +318,11 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), Env: []string{ - "PJRT_DEVICE=TPU", + fmt.Sprintf("PJRT_DEVICE=%s", dockerParams.DockerPJRTDevice()), }, } hostConfig := &container.HostConfig{ - Privileged: dockerParams.DockerPrivileged(), + Privileged: dockerParams.DockerPrivileged(), NetworkMode: getNetworkMode(), PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, @@ -434,6 +434,10 @@ func (c CLIArgs) DockerPrivileged() bool { return c.Docker.Privileged } +func (c CLIArgs) DockerPJRTDevice() string { + return c.Docker.PJRTDevice +} + func (c CLIArgs) DockerShellCommands(publicKeys []string) []string { concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys if len(publicKeys) > 0 { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 5be30ce0b2..5823bdfdf1 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -102,6 +102,10 @@ func (c *dockerParametersMock) DockerPrivileged() bool { return false } +func (c *dockerParametersMock) DockerPJRTDevice() string { + return "" +} + func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string { userPublicKey := c.publicSSHKey if len(publicKeys) > 0 { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index e04fe16b02..d596ef5c39 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -10,12 +10,13 @@ import ( ) type DockerParameters interface { - DockerPrivileged() bool + DockerPrivileged() bool DockerKeepContainer() bool DockerShellCommands([]string) []string DockerMounts(string) ([]mount.Mount, error) DockerPorts() []int MakeRunnerDir() (string, error) + DockerPJRTDevice() string } type CLIArgs struct { @@ -40,6 +41,7 @@ type CLIArgs struct { KeepContainer bool ConcatinatedPublicSSHKeys string Privileged bool + PJRTDevice string } } diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index a9b49d55e4..988c37fbd9 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -120,14 +120,16 @@ def get_shim_env(build: str, authorized_keys: List[str]) -> Dict[str, str]: return envs -def get_shim_commands(authorized_keys: List[str], *, is_privileged: bool = False) -> List[str]: +def get_shim_commands( + authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None +) -> List[str]: build = get_dstack_runner_version() commands = get_shim_pre_start_commands( build, ) for k, v in get_shim_env(build, authorized_keys).items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script(is_privileged) + commands += get_run_shim_script(is_privileged, pjrt_device) return commands @@ -160,11 +162,13 @@ def get_shim_pre_start_commands(build: str) -> List[str]: ] -def get_run_shim_script(is_privileged: bool) -> List[str]: +def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" privileged_flag = "--privileged" if is_privileged else "" + pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else "" + return [ - f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", + f"nohup dstack-shim {dev_flag} docker --keep-container {privileged_flag} {pjrt_device_env} >{DSTACK_WORKING_DIR}/shim.log 2>&1 &", ] diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 83c4cd2196..62bfd8a66d 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -443,7 +443,9 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]: def _get_tpu_startup_script(authorized_keys: List[str]) -> str: - commands = get_shim_commands(authorized_keys=authorized_keys, is_privileged=True) + commands = get_shim_commands( + authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU" + ) startup_script = " ".join([" && ".join(commands)]) startup_script = "#! /bin/bash\n" + startup_script return startup_script From c4e0acb8cfa3a8897e83e9a6b3a7c7a5a986f0b0 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 26 Jun 2024 13:01:42 +0545 Subject: [PATCH 4/4] Fix conditionally set PJRT_DEVICE env variable to avoid empty string --- runner/internal/shim/docker.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 7391c22bc4..4d0cda007f 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -312,14 +312,18 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str return "", tracerr.Wrap(err) } + //Set the environment variables + envVars := []string{} + if dockerParams.DockerPJRTDevice() != "" { + envVars = append(envVars, fmt.Sprintf("PJRT_DEVICE=%s", dockerParams.DockerPJRTDevice())) + } + containerConfig := &container.Config{ Image: taskConfig.ImageName, Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), - Env: []string{ - fmt.Sprintf("PJRT_DEVICE=%s", dockerParams.DockerPJRTDevice()), - }, + Env: envVars, } hostConfig := &container.HostConfig{ Privileged: dockerParams.DockerPrivileged(),